From 486feef242c15f96ed3df061cf92d93c22892987 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Perceval=20Wajsb=C3=BCrt?= Date: Wed, 7 Feb 2024 04:12:59 +0100 Subject: [PATCH] refacto: align data api with edsnlp --- docs/trainable-pipes.md | 6 +- edspdf/__init__.py | 1 + edspdf/accelerators/base.py | 97 +- edspdf/accelerators/multiprocessing.py | 522 +--------- edspdf/accelerators/simple.py | 92 -- edspdf/lazy_collection.py | 325 +++++++ edspdf/pipeline.py | 405 +++++--- edspdf/pipes/classifiers/trainable.py | 17 +- .../pipes/embeddings/box_layout_embedding.py | 6 +- .../embeddings/box_layout_preprocessor.py | 3 +- .../pipes/embeddings/huggingface_embedding.py | 31 +- .../pipes/embeddings/simple_text_embedding.py | 11 +- edspdf/processing/__init__.py | 9 + edspdf/processing/multiprocessing.py | 893 ++++++++++++++++++ edspdf/processing/simple.py | 68 ++ edspdf/registry.py | 2 + edspdf/trainable_pipe.py | 276 ++++-- edspdf/utils/collections.py | 141 ++- edspdf/utils/lazy_module.py | 108 +++ edspdf/utils/package.py | 97 +- pyproject.toml | 6 +- tests/core/test_data.py | 66 ++ tests/core/test_pipeline.py | 38 +- tests/utils/test_package.py | 8 +- 24 files changed, 2169 insertions(+), 1059 deletions(-) delete mode 100644 edspdf/accelerators/simple.py create mode 100644 edspdf/lazy_collection.py create mode 100644 edspdf/processing/__init__.py create mode 100644 edspdf/processing/multiprocessing.py create mode 100644 edspdf/processing/simple.py create mode 100644 edspdf/utils/lazy_module.py create mode 100644 tests/core/test_data.py diff --git a/docs/trainable-pipes.md b/docs/trainable-pipes.md index b9dc02c8..ae6f37b0 100644 --- a/docs/trainable-pipes.md +++ b/docs/trainable-pipes.md @@ -97,12 +97,12 @@ class MyComponent(TrainablePipe): "my-feature": ...(doc), } - def collate(self, batch, device: torch.device) -> Dict: + def collate(self, batch) -> Dict: # Collate the features of the "embedding" subcomponent # and the features of this component as well return { - "embedding": self.embedding.collate(batch["embedding"], device), - "my-feature": torch.as_tensor(batch["my-feature"], device=device), + "embedding": self.embedding.collate(batch["embedding"]), + "my-feature": torch.as_tensor(batch["my-feature"]), } def forward(self, batch: Dict, supervision=False) -> Dict: diff --git a/edspdf/__init__.py b/edspdf/__init__.py index 9d5ed750..2c0ccc2e 100644 --- a/edspdf/__init__.py +++ b/edspdf/__init__.py @@ -3,6 +3,7 @@ from .pipeline import Pipeline, load from .registry import registry from .structures import Box, Page, PDFDoc, Text, TextBox, TextProperties +from . import data from . 
import utils # isort:skip diff --git a/edspdf/accelerators/base.py b/edspdf/accelerators/base.py index 07bc01b2..2360377d 100644 --- a/edspdf/accelerators/base.py +++ b/edspdf/accelerators/base.py @@ -1,97 +1,2 @@ -from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Union - -from ..structures import PDFDoc - - -class FromDictFieldsToDoc: - def __init__(self, content_field, id_field=None): - self.content_field = content_field - self.id_field = id_field - - def __call__(self, item): - if isinstance(item, dict): - return PDFDoc( - content=item[self.content_field], - id=item[self.id_field] if self.id_field else None, - ) - return item - - -class ToDoc: - @classmethod - def __get_validators__(cls): - yield cls.validate - - @classmethod - def validate(cls, value, config=None): - if isinstance(value, str): - value = {"content_field": value} - if isinstance(value, dict): - value = FromDictFieldsToDoc(**value) - if callable(value): - return value - raise TypeError( - f"Invalid entry {value} ({type(value)}) for ToDoc, " - f"expected string, a dict or a callable." - ) - - -FROM_DOC_TO_DICT_FIELDS_TEMPLATE = """ -def fn(doc): - return {X} -""" - - -class FromDocToDictFields: - def __init__(self, mapping): - self.mapping = mapping - dict_fields = ", ".join(f"{repr(k)}: doc.{v}" for k, v in mapping.items()) - local_vars = {} - exec(FROM_DOC_TO_DICT_FIELDS_TEMPLATE.replace("X", dict_fields), local_vars) - self.fn = local_vars["fn"] - - def __reduce__(self): - return FromDocToDictFields, (self.mapping,) - - def __call__(self, doc): - return self.fn(doc) - - -class FromDoc: - """ - A FromDoc converter (from a PDFDoc to an arbitrary type) can be either: - - - a dict mapping field names to doc attributes - - a callable that takes a PDFDoc and returns an arbitrary type - """ - - @classmethod - def __get_validators__(cls): - yield cls.validate - - @classmethod - def validate(cls, value, config=None): - if isinstance(value, dict): - value = FromDocToDictFields(value) - if callable(value): - return value - raise TypeError( - f"Invalid entry {value} ({type(value)}) for ToDoc, " - f"expected dict or callable" - ) - - class Accelerator: - def __call__( - self, - inputs: Iterable[Any], - model: Any, - to_doc: ToDoc = FromDictFieldsToDoc("content"), - from_doc: FromDoc = lambda doc: doc, - ): - raise NotImplementedError() - - -if TYPE_CHECKING: - ToDoc = Union[str, Dict[str, Any], Callable[[Any], PDFDoc]] # noqa: F811 - FromDoc = Union[Dict[str, Any], Callable[[PDFDoc], Any]] # noqa: F811 + pass diff --git a/edspdf/accelerators/multiprocessing.py b/edspdf/accelerators/multiprocessing.py index 41ef2f80..b11d556e 100644 --- a/edspdf/accelerators/multiprocessing.py +++ b/edspdf/accelerators/multiprocessing.py @@ -1,338 +1,15 @@ -import gc -import signal -from multiprocessing.connection import wait -from random import shuffle -from typing import Any, Iterable, List, Optional, Union +from typing import List, Optional, Union import torch -import torch.multiprocessing as mp -from .. 
import TrainablePipe from ..registry import registry -from ..utils.collections import batchify -from .base import Accelerator, FromDictFieldsToDoc, FromDoc, ToDoc - -DEBUG = False - -debug = ( - (lambda *args, flush=False, **kwargs: print(*args, **kwargs, flush=True)) - if DEBUG - else lambda *args, **kwargs: None -) - - -class Exchanger: - def __init__( - self, - num_stages, - num_gpu_workers, - num_cpu_workers, - gpu_worker_devices, - ): - # queue for cpu input tasks - self.gpu_worker_devices = gpu_worker_devices - # We add prioritized queue at the end for STOP signals - self.cpu_inputs_queues = [ - [mp.SimpleQueue()] + [mp.SimpleQueue() for _ in range(num_stages + 1)] - # The input queue is not shared between processes, since calling `wait` - # on a queue reader from multiple processes may lead to a deadlock - for _ in range(num_cpu_workers) - ] - self.gpu_inputs_queues = [ - [mp.SimpleQueue() for _ in range(num_stages + 1)] - for _ in range(num_gpu_workers) - ] - self.outputs_queue = mp.Queue() - - def get_cpu_tasks(self, idx): - while True: - queue_readers = wait( - [queue._reader for queue in self.cpu_inputs_queues[idx]] - ) - stage, queue = next( - (stage, q) - for stage, q in reversed(list(enumerate(self.cpu_inputs_queues[idx]))) - if q._reader in queue_readers - ) - try: - item = queue.get() - except BaseException: - continue - if item is None: - return - yield stage, item - - def put_cpu(self, item, stage, idx): - return self.cpu_inputs_queues[idx][stage].put(item) - - def get_gpu_tasks(self, idx): - while True: - queue_readers = wait( - [queue._reader for queue in self.gpu_inputs_queues[idx]] - ) - stage, queue = next( - (stage, q) - for stage, q in reversed(list(enumerate(self.gpu_inputs_queues[idx]))) - if q._reader in queue_readers - ) - try: - item = queue.get() - except BaseException: # pragma: no cover - continue - if item is None: - return - yield stage, item - - def put_gpu(self, item, stage, idx): - return self.gpu_inputs_queues[idx][stage].put(item) - - def put_results(self, items): - self.outputs_queue.put(items) - - def iter_results(self): - for out in iter(self.outputs_queue.get, None): - yield out - - -class CPUWorker(mp.Process): - def __init__( - self, - cpu_idx: int, - exchanger: Exchanger, - gpu_pipe_names: List[str], - model: Any, - device: Union[str, torch.device], - ): - super(CPUWorker, self).__init__() - - self.cpu_idx = cpu_idx - self.exchanger = exchanger - self.gpu_pipe_names = gpu_pipe_names - self.model = model - self.device = device - - def _run(self): - # Cannot pass torch tensor during init i think ? 
otherwise i get - # ValueError: bad value(s) in fds_to_keep - mp._prctl_pr_set_pdeathsig(signal.SIGINT) - - model = self.model.to(self.device) - stages = [{"cpu_components": [], "gpu_component": None}] - for name, component in model.pipeline: - if name in self.gpu_pipe_names: - stages[-1]["gpu_component"] = component - stages.append({"cpu_components": [], "gpu_component": None}) - else: - stages[-1]["cpu_components"].append(component) - - next_batch_id = 0 - active_batches = {} - debug( - f"CPU worker {self.cpu_idx} is ready", - next(model.parameters()).device, - flush=True, - ) - - had_error = False - with torch.no_grad(): - for stage, task in self.exchanger.get_cpu_tasks(self.cpu_idx): - if had_error: - continue # pragma: no cover - try: - if stage == 0: - gpu_idx = None - batch_id = next_batch_id - debug("preprocess start for", batch_id) - next_batch_id += 1 - docs = task - else: - gpu_idx, batch_id, result = task - debug("postprocess start for", batch_id) - docs = active_batches.pop(batch_id) - gpu_pipe = stages[stage - 1]["gpu_component"] - docs = gpu_pipe.postprocess(docs, result) # type: ignore - - for component in stages[stage]["cpu_components"]: - if hasattr(component, "batch_process"): - docs = component.batch_process(docs) - else: - docs = [component(doc) for doc in docs] - - gpu_pipe = stages[stage]["gpu_component"] - if gpu_pipe is not None: - preprocessed = gpu_pipe.make_batch(docs) # type: ignore - active_batches[batch_id] = docs - if gpu_idx is None: - gpu_idx = batch_id % len(self.exchanger.gpu_worker_devices) - collated = gpu_pipe.collate( # type: ignore - preprocessed, - device=self.exchanger.gpu_worker_devices[gpu_idx], - ) - self.exchanger.put_gpu( - item=(self.cpu_idx, batch_id, collated), - idx=gpu_idx, - stage=stage, - ) - batch_id += 1 - debug("preprocess end for", batch_id) - else: - self.exchanger.put_results((docs, self.cpu_idx, gpu_idx)) - debug("postprocess end for", batch_id) - except BaseException as e: - had_error = True - import traceback - - print(traceback.format_exc(), flush=True) - self.exchanger.put_results((e, self.cpu_idx, None)) - # We need to drain the queues of GPUWorker fed inputs (pre-moved to GPU) - # to ensure no tensor allocated on producer processes (CPUWorker via - # collate) are left in consumer processes - debug("Start draining CPU worker", self.cpu_idx) - [None for _ in self.exchanger.get_cpu_tasks(self.cpu_idx)] - debug(f"CPU worker {self.cpu_idx} is about to stop") - - def run(self): - self._run() - self.model = None - gc.collect() - torch.cuda.empty_cache() - - -class GPUWorker(mp.Process): - def __init__( - self, - gpu_idx, - exchanger: Exchanger, - gpu_pipe_names: List[str], - model: Any, - device: Union[str, torch.device], - ): - super().__init__() - - self.device = device - self.gpu_idx = gpu_idx - self.exchanger = exchanger - - self.gpu_pipe_names = gpu_pipe_names - self.model = model - self.device = device - - def _run(self): - debug("GPU worker", self.gpu_idx, "started") - mp._prctl_pr_set_pdeathsig(signal.SIGINT) - had_error = False - - model = self.model.to(self.device) - stage_components = [model.get_pipe(name) for name in self.gpu_pipe_names] - del model - with torch.no_grad(): - for stage, task in self.exchanger.get_gpu_tasks(self.gpu_idx): - if had_error: - continue # pragma: no cover - try: - cpu_idx, batch_id, batch = task - debug("forward start for", batch_id) - component = stage_components[stage] - res = component.module_forward(batch) - del batch, task - # TODO set non_blocking=True here - res = { - key: 
val.to("cpu") if not isinstance(val, int) else val - for key, val in res.items() - } - self.exchanger.put_cpu( - item=(self.gpu_idx, batch_id, res), - stage=stage + 1, - idx=cpu_idx, - ) - debug("forward end for", batch_id) - except BaseException as e: - had_error = True - self.exchanger.put_results((e, None, self.gpu_idx)) - import traceback - - print(traceback.format_exc(), flush=True) - task = batch = res = None # noqa - # We need to drain the queues of CPUWorker fed inputs (pre-moved to GPU) - # to ensure no tensor allocated on producer processes (CPUWorker via - # collate) are left in consumer processes - debug("Start draining GPU worker", self.gpu_idx) - [None for _ in self.exchanger.get_gpu_tasks(self.gpu_idx)] - debug(f"GPU worker {self.gpu_idx} is about to stop") - - def run(self): - self._run() - self.model = None - gc.collect() - torch.cuda.empty_cache() - - -DEFAULT_MAX_CPU_WORKERS = 4 +from .base import Accelerator @registry.accelerator.register("multiprocessing") class MultiprocessingAccelerator(Accelerator): """ - If you have multiple CPU cores, and optionally multiple GPUs, we provide a - `multiprocessing` accelerator that allows to run the inference on multiple - processes. - - This accelerator dispatches the batches between multiple workers - (data-parallelism), and distribute the computation of a given batch on one or two - workers (model-parallelism). This is done by creating two types of workers: - - - a `CPUWorker` which handles the non deep-learning components and the - preprocessing, collating and postprocessing of deep-learning components - - a `GPUWorker` which handles the forward call of the deep-learning components - - The advantage of dedicating a worker to the deep-learning components is that it - allows to prepare multiple batches in parallel in multiple `CPUWorker`, and ensure - that the `GPUWorker` never wait for a batch to be ready. - - The overall architecture described in the following figure, for 3 CPU workers and 2 - GPU workers. - -
-    [figure: overall architecture with CPU and GPU workers]
- - Here is how a small pipeline with rule-based components and deep-learning components - is distributed between the workers: - -
-    [figure: distribution of a pipeline's components across workers]
- - Examples - -------- - - ```python - docs = list( - pipeline.pipe( - [content1, content2, ...], - accelerator={ - "@accelerator": "multiprocessing", - "num_cpu_workers": 3, - "num_gpu_workers": 2, - "batch_size": 8, - }, - ) - ) - ``` - - Parameters - ---------- - batch_size: int - Number of documents to process at a time in a CPU/GPU worker - num_cpu_workers: int - Number of CPU workers. A CPU worker handles the non deep-learning components - and the preprocessing, collating and postprocessing of deep-learning components. - num_gpu_workers: Optional[int] - Number of GPU workers. A GPU worker handles the forward call of the - deep-learning components. - gpu_pipe_names: Optional[List[str]] - List of pipe names to accelerate on a GPUWorker, defaults to all pipes - that inherit from TrainablePipe + Deprecated: Use `docs.map_pipeline(model).set_processing(...)` instead """ def __init__( @@ -350,196 +27,3 @@ def __init__( self.gpu_pipe_names = gpu_pipe_names self.gpu_worker_devices = gpu_worker_devices self.cpu_worker_devices = cpu_worker_devices - - def __call__( - self, - inputs: Iterable[Any], - model: Any, - to_doc: ToDoc = FromDictFieldsToDoc("content"), - from_doc: FromDoc = lambda doc: doc, - ): - """ - Stream of documents to process. Each document can be a string or a tuple - - Parameters - ---------- - inputs - model - - Yields - ------ - Any - Processed outputs of the pipeline - """ - if torch.multiprocessing.get_start_method() != "spawn": - torch.multiprocessing.set_start_method("spawn", force=True) - - gpu_pipe_names = ( - [ - name - for name, component in model.pipeline - if isinstance(component, TrainablePipe) - ] - if self.gpu_pipe_names is None - else self.gpu_pipe_names - ) - - if not all(model.has_pipe(name) for name in gpu_pipe_names): - raise ValueError( - "GPU accelerated pipes {} could not be found in the model".format( - sorted(set(model.pipe_names) - set(gpu_pipe_names)) - ) - ) - - num_devices = torch.cuda.device_count() - print(f"Number of available devices: {num_devices}", flush=True) - - num_cpu_workers = self.num_cpu_workers - num_gpu_workers = self.num_gpu_workers - - if num_gpu_workers is None: - num_gpu_workers = num_devices if len(gpu_pipe_names) > 0 else 0 - - if num_cpu_workers is None: - num_cpu_workers = max( - min(mp.cpu_count() - num_gpu_workers, DEFAULT_MAX_CPU_WORKERS), 0 - ) - - if num_gpu_workers == 0: - gpu_pipe_names = [] - - gpu_worker_devices = ( - [ - torch.device(f"cuda:{gpu_idx * num_devices // num_gpu_workers}") - for gpu_idx in range(num_gpu_workers) - ] - if self.gpu_worker_devices is None - else self.gpu_worker_devices - ) - cpu_worker_devices = ( - ["cpu"] * num_cpu_workers - if self.cpu_worker_devices is None - else self.cpu_worker_devices - ) - assert len(cpu_worker_devices) == num_cpu_workers - assert len(gpu_worker_devices) == num_gpu_workers - if num_cpu_workers == 0: - ( - num_cpu_workers, - num_gpu_workers, - cpu_worker_devices, - gpu_worker_devices, - gpu_pipe_names, - ) = (num_gpu_workers, 0, gpu_worker_devices, [], []) - - debug(f"Number of CPU workers: {num_cpu_workers}") - debug(f"Number of GPU workers: {num_gpu_workers}") - - exchanger = Exchanger( - num_stages=len(gpu_pipe_names), - num_cpu_workers=num_cpu_workers, - num_gpu_workers=num_gpu_workers, - gpu_worker_devices=gpu_worker_devices, - ) - - cpu_workers = [] - gpu_workers = [] - model = model.to("cpu") - - for gpu_idx in range(num_gpu_workers): - gpu_workers.append( - GPUWorker( - gpu_idx=gpu_idx, - exchanger=exchanger, - gpu_pipe_names=gpu_pipe_names, - model=model, 
- device=gpu_worker_devices[gpu_idx], - ) - ) - - for cpu_idx in range(num_cpu_workers): - cpu_workers.append( - CPUWorker( - cpu_idx=cpu_idx, - exchanger=exchanger, - gpu_pipe_names=gpu_pipe_names, - model=model, - device=cpu_worker_devices[cpu_idx], - ) - ) - - for worker in (*cpu_workers, *gpu_workers): - worker.start() - - try: - num_max_enqueued = num_cpu_workers * 2 + 10 - # Number of input/output batch per process - total_inputs = [0] * num_cpu_workers - total_outputs = [0] * num_cpu_workers - outputs_iterator = exchanger.iter_results() - - cpu_worker_indices = list(range(num_cpu_workers)) - inputs_iterator = (to_doc(i) for i in inputs) - for i, pdfs_batch in enumerate(batchify(inputs_iterator, self.batch_size)): - if sum(total_inputs) - sum(total_outputs) >= num_max_enqueued: - outputs, cpu_idx, gpu_idx = next(outputs_iterator) - if isinstance(outputs, BaseException): - raise outputs # pragma: no cover - yield from (from_doc(o) for o in outputs) - total_outputs[cpu_idx] += 1 - - # Shuffle to ensure the first process does not receive all the documents - # in case of total_inputs - total_outputs equality - shuffle(cpu_worker_indices) - cpu_idx = min( - cpu_worker_indices, - key=lambda i: total_inputs[i] - total_outputs[i], - ) - exchanger.put_cpu(pdfs_batch, stage=0, idx=cpu_idx) - total_inputs[cpu_idx] += 1 - - while sum(total_outputs) < sum(total_inputs): - outputs, cpu_idx, gpu_idx = next(outputs_iterator) - if isinstance(outputs, BaseException): - raise outputs # pragma: no cover - yield from (from_doc(o) for o in outputs) - total_outputs[cpu_idx] += 1 - finally: - # Send gpu and cpu process the order to stop processing data - # We use the prioritized queue to ensure the stop signal is processed - # before the next batch of data - for i, worker in enumerate(gpu_workers): - exchanger.gpu_inputs_queues[i][-1].put(None) - debug("Asked gpu worker", i, "to stop processing data") - for i, worker in enumerate(cpu_workers): - exchanger.cpu_inputs_queues[i][-1].put(None) - debug("Asked cpu worker", i, "to stop processing data") - - # Enqueue a final non prioritized STOP signal to ensure there remains no - # data in the queues (cf drain loop in CPUWorker / GPUWorker) - for i, worker in enumerate(gpu_workers): - exchanger.gpu_inputs_queues[i][0].put(None) - debug("Asked gpu", i, "to end") - for i, worker in enumerate(gpu_workers): - worker.join(timeout=5) - debug("Joined gpu worker", i) - for i, worker in enumerate(cpu_workers): - exchanger.cpu_inputs_queues[i][0].put(None) - debug("Asked cpu", i, "to end") - for i, worker in enumerate(cpu_workers): - worker.join(timeout=1) - debug("Joined cpu worker", i) - - # If a worker is still alive, kill it - # This should not happen, but for a reason I cannot explain, it does in - # some CPU workers sometimes when we catch an error, even though each run - # method of the workers completes cleanly. Maybe this has something to do - # with the cleanup of these processes ? 
- for i, worker in enumerate(gpu_workers): # pragma: no cover - if worker.is_alive(): - print("Killing gpu worker", i) - worker.kill() - for i, worker in enumerate(cpu_workers): # pragma: no cover - if worker.is_alive(): - print("Killing cpu worker", i) - worker.kill() diff --git a/edspdf/accelerators/simple.py b/edspdf/accelerators/simple.py deleted file mode 100644 index 340dbebb..00000000 --- a/edspdf/accelerators/simple.py +++ /dev/null @@ -1,92 +0,0 @@ -from typing import Any, Dict, Iterable - -import torch - -from ..registry import registry -from ..utils.collections import batchify -from .base import Accelerator, FromDictFieldsToDoc, FromDoc, ToDoc - - -@registry.accelerator.register("simple") -class SimpleAccelerator(Accelerator): - """ - This is the simplest accelerator which batches the documents and process each batch - on the main process (the one calling `.pipe()`). - - Examples - -------- - - ```python - docs = list(pipeline.pipe([content1, content2, ...])) - ``` - - or, if you want to override the model defined batch size - - ```python - docs = list(pipeline.pipe([content1, content2, ...], batch_size=8)) - ``` - - which is equivalent to passing a confit dict - - ```python - docs = list( - pipeline.pipe( - [content1, content2, ...], - accelerator={ - "@accelerator": "simple", - "batch_size": 8, - }, - ) - ) - ``` - - or the instantiated accelerator directly - - ```python - from edspdf.accelerators.simple import SimpleAccelerator - - accelerator = SimpleAccelerator(batch_size=8) - docs = list(pipeline.pipe([content1, content2, ...], accelerator=accelerator)) - ``` - - If you have a GPU, make sure to move the model to the appropriate device before - calling `.pipe()`. If you have multiple GPUs, use the - [multiprocessing][edspdf.accelerators.multiprocessing.MultiprocessingAccelerator] - accelerator instead. - - ```python - pipeline.to("cuda") - docs = list(pipeline.pipe([content1, content2, ...])) - ``` - - Parameters - ---------- - batch_size: int - The number of documents to process in each batch. 
- """ - - def __init__( - self, - *, - batch_size: int = 32, - ): - self.batch_size = batch_size - - def __call__( - self, - inputs: Iterable[Any], - model: Any, - to_doc: ToDoc = FromDictFieldsToDoc("content"), - from_doc: FromDoc = lambda doc: doc, - component_cfg: Dict[str, Dict[str, Any]] = None, - ): - docs = (to_doc(doc) for doc in inputs) - for batch in batchify(docs, batch_size=self.batch_size): - with torch.no_grad(), model.cache(), model.train(False): - for name, pipe in model.pipeline: - if name not in model._disabled: - if hasattr(pipe, "batch_process"): - batch = pipe.batch_process(batch) - else: - batch = [pipe(doc) for doc in batch] # type: ignore - yield from (from_doc(doc) for doc in batch) diff --git a/edspdf/lazy_collection.py b/edspdf/lazy_collection.py new file mode 100644 index 00000000..bea7e7de --- /dev/null +++ b/edspdf/lazy_collection.py @@ -0,0 +1,325 @@ +from __future__ import annotations + +import contextlib +import sys +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Container, + Dict, + Iterable, + List, + Optional, + Tuple, +) + +from typing_extensions import Literal + +import edspdf.data + +if TYPE_CHECKING: + import torch + + from edspdf import Pipeline + from edspdf.data.base import BaseReader, BaseWriter + from edspdf.trainable_pipe import TrainablePipe + +INFER = type("INFER", (), {"__repr__": lambda self: "INFER"})() + + +class MetaLazyCollection(type): + def __getattr__(self, item): + if item in edspdf.data.__all__: + fn = getattr(edspdf.data, item) + setattr(self, item, fn) + return fn + raise AttributeError(item) + + def __dir__(self): # pragma: no cover + return (*super().__dir__(), *edspdf.data.__all__) + + +class LazyCollection(metaclass=MetaLazyCollection): + def __init__( + self, + reader: Optional[BaseReader] = None, + writer: Optional[BaseWriter] = None, + pipeline: List[Any] = [], + config={}, + ): + self.reader = reader + self.writer = writer + self.pipeline: List[Tuple[str, Callable, Dict]] = pipeline + self.config = config + + @property + def batch_size(self): + return self.config.get("batch_size", 1) + + @property + def batch_unit(self): + return self.config.get("batch_unit", 1) + + @property + def sort_by_size(self): + return self.config.get("sort_by_size", False) + + @property + def chunk_size(self): + return self.config.get("chunk_size", 1024) + + @property + def num_cpu_workers(self): + return self.config.get("num_cpu_workers") + + @property + def num_gpu_workers(self): + return self.config.get("num_gpu_workers") + + @property + def gpu_pipe_names(self): + return self.config.get("gpu_pipe_names") + + @property + def gpu_worker_devices(self): + return self.config.get("gpu_worker_devices") + + @property + def cpu_worker_devices(self): + return self.config.get("cpu_worker_devices") + + @property + def backend(self): + return self.config.get("backend") + + @property + def show_progress(self): + return self.config.get("show_progress") + + @property + def process_start_method(self): + return self.config.get("process_start_method") + + def set_processing( + self, + batch_size: int = INFER, + batch_unit: Literal["docs", "pages", "lines"] = INFER, + chunk_size: int = INFER, + sort_by_size: bool = INFER, + num_cpu_workers: int = INFER, + num_gpu_workers: int = INFER, + backend: Literal["simple", "multiprocessing", "spark"] = INFER, + gpu_pipe_names: List[str] = INFER, + show_progress: bool = INFER, + process_start_method: bool = INFER, + gpu_worker_devices: List[str] = INFER, + cpu_worker_devices: List[str] = INFER, + ) -> 
"LazyCollection": + """ + Parameters + ---------- + batch_size: int + Number of documents to process at a time in a GPU worker (or in the + main process if no workers are used). + batch_unit: Literal["docs", "pages", "lines"] + The unit of the batch size. Can be "docs" or "words". If "words", the + batch size is total number of words in the documents. + chunk_size: int + Number of documents to build before splitting into batches. Only used + with "simple" and "multiprocessing" backends. This is also the number of + documents that will be passed through the first components of the pipeline + until a GPU worker is used (then the will be split according to the + `batch_size` and `batch_unit`). + sort_by_size: bool + Whether to sort the documents by size before splitting into batches. + num_cpu_workers: int + Number of CPU workers. A CPU worker handles the non deep-learning components + and the preprocessing, collating and postprocessing of deep-learning + components. If no GPU workers are used, the CPU workers also handle the + forward call of the deep-learning components. + num_gpu_workers: Optional[int] + Number of GPU workers. A GPU worker handles the forward call of the + deep-learning components. Only used with "multiprocessing" backend. + gpu_pipe_names: Optional[List[str]] + List of pipe names to accelerate on a GPUWorker, defaults to all pipes + that inherit from TrainablePipe. Only used with "multiprocessing" backend. + backend: Optional[Literal["simple", "multiprocessing", "spark"]] + The backend to use for parallel processing. If not set, the backend is + automatically selected based on the input data and the number of workers. + + - "simple" is the default backend and is used when `num_cpu_workers` is 1 + and `num_gpu_workers` is 0. + - "multiprocessing" is used when `num_cpu_workers` is greater than 1 or + `num_gpu_workers` is greater than 0. + - "spark" is used when the input data is a Spark dataframe and the output + writer is a Spark writer. + show_progress: Optional[bool] + Whether to show progress bars (only applicable with "simple" and + "multiprocessing" backends). + process_start_method: Optional[bool] + Whether to use "fork" or "spawn" as the start method for the multiprocessing + backend. + + - "fork" is the default start method on Unix systems and is the fastest + start method, but it is not available on Windows, can cause issues + with CUDA and is not safe when using multiple threads. + - "spawn" is the default start method on Windows and is the safest start + method, but it is not available on Unix systems and is slower than + "fork". + gpu_worker_devices: Optional[List[str]] + List of GPU devices to use for the GPU workers. Defaults to all available + devices, one worker per device. Only used with "multiprocessing" backend. + cpu_worker_devices: Optional[List[str]] + List of GPU devices to use for the CPU workers. Used for debugging purposes. 
+ + Returns + ------- + LazyCollection + """ + kwargs = dict(locals()) + kwargs.pop("self") + return LazyCollection( + reader=self.reader, + writer=self.writer, + pipeline=self.pipeline, + config={ + **self.config, + **{k: v for k, v in kwargs.items() if v is not INFER}, + }, + ) + + @classmethod + def ensure_lazy(cls, data): + from edspdf.data.base import IterableReader + + if isinstance(data, cls): + return data + return cls(reader=IterableReader(data)) + + def map(self, pipe, name: Optional[str] = None, kwargs={}) -> "LazyCollection": + return LazyCollection( + reader=self.reader, + writer=self.writer, + pipeline=[*self.pipeline, (name, pipe, kwargs)], + config=self.config, + ) + + def map_pipeline(self, model: Pipeline) -> "LazyCollection": + new_steps = [] + for name, pipe, kwargs in self.pipeline: + new_steps.append((name, pipe, kwargs)) + + new_steps.append(("_ensure_doc", model.ensure_doc, {})) + + for name, pipe in model.pipeline: + if name not in model._disabled: + new_steps.append((name, pipe, {})) + config = ( + {**self.config, "batch_size": model.batch_size} + if self.batch_size is None + else self.config + ) + return LazyCollection( + reader=self.reader, + writer=self.writer, + pipeline=new_steps, + config=config, + ) + + def write(self, writer: BaseWriter, execute: bool = True) -> Any: + lc = LazyCollection( + reader=self.reader, + writer=writer, + pipeline=self.pipeline, + config=self.config, + ) + return lc.execute() if execute else lc + + def execute(self): + import edspdf.processing + + backend = self.backend + if backend is None: + try: + SparkReader = sys.modules.get("edspdf.data.spark").SparkReader + SparkWriter = sys.modules.get("edspdf.data.spark").SparkWriter + except (KeyError, AttributeError): # pragma: no cover + SparkReader = SparkWriter = None + if ( + SparkReader + and isinstance(self.reader, SparkReader) + and SparkWriter + and (self.writer is None or isinstance(self.writer, SparkWriter)) + ): + backend = "spark" + elif ( + self.num_cpu_workers is not None or self.num_gpu_workers is not None + ) and ( + self.num_cpu_workers is not None + and self.num_cpu_workers > 1 + or self.num_gpu_workers is not None + and self.num_gpu_workers > 0 + ): + backend = "multiprocessing" + else: + backend = "simple" + execute = getattr(edspdf.processing, f"execute_{backend}_backend") + return execute(self) + + def __iter__(self): + return iter(self.execute()) + + @contextlib.contextmanager + def cache(self): + for name, pipe, *_ in self.pipeline: + if hasattr(pipe, "enable_cache"): + pipe.enable_cache() + yield + for name, pipe, *_ in self.pipeline: + if hasattr(pipe, "disable_cache"): + pipe.disable_cache() + + def torch_components( + self, disable: Container[str] = () + ) -> Iterable[Tuple[str, "TrainablePipe"]]: + """ + Yields components that are PyTorch modules. + + Parameters + ---------- + disable: Container[str] + The names of disabled components, which will be skipped. 
+ + Returns + ------- + Iterable[Tuple[str, TrainablePipe]] + """ + for name, pipe, *_ in self.pipeline: + if name not in disable and hasattr(pipe, "forward"): + yield name, pipe + + def to(self, device: Union[str, Optional["torch.device"]] = None): # noqa F821 + """Moves the pipeline to a given device""" + for name, pipe, *_ in self.torch_components(): + pipe.to(device) + return self + + def worker_copy(self): + return LazyCollection( + reader=self.reader.worker_copy(), + writer=self.writer, + pipeline=self.pipeline, + config=self.config, + ) + + def __dir__(self): # pragma: no cover + return (*super().__dir__(), *edspdf.data.__all__) + + def __getattr__(self, item): + return getattr(LazyCollection, item).__get__(self) + + +if TYPE_CHECKING: + # just to add read/from_* and write/to_* methods to the static type hints + LazyCollection = edspdf.data # noqa: F811 diff --git a/edspdf/pipeline.py b/edspdf/pipeline.py index 05e1d226..1fc36f67 100644 --- a/edspdf/pipeline.py +++ b/edspdf/pipeline.py @@ -1,13 +1,15 @@ import functools +import importlib +import inspect import json import os import shutil import warnings -from collections import defaultdict from contextlib import contextmanager from enum import Enum from pathlib import Path from typing import ( + TYPE_CHECKING, Any, Callable, Dict, @@ -24,14 +26,13 @@ from confit import Config from confit.errors import ConfitValidationError, patch_errors -from confit.utils.collections import join_path, split_path from confit.utils.xjson import Reference -from pydantic import parse_obj_as from typing_extensions import Literal import edspdf -from .accelerators.base import Accelerator, FromDoc, ToDoc +from .data.converters import FILENAME +from .lazy_collection import LazyCollection from .registry import CurriedFactory, registry from .structures import PDFDoc from .utils.collections import ( @@ -41,6 +42,9 @@ multi_tee, ) +if TYPE_CHECKING: + import torch + EMPTY_LIST = FrozenList() @@ -210,17 +214,16 @@ def add_pipe( pipe = factory if hasattr(pipe, "name"): if name is not None and name != pipe.name: - raise ValueError( + warnings.warn( "The provided name does not match the name of the component." ) + pipe.name = name else: name = pipe.name - else: - if name is None: - raise ValueError( - "The component does not have a name, so you must provide one", - ) - pipe.name = name + if name is None: + raise ValueError( + "The component does not have a name, so you must provide one", + ) assert sum([before is not None, after is not None, first]) <= 1, ( "You can only use one of before, after, or first", ) @@ -236,6 +239,18 @@ def add_pipe( self._components.insert(insertion_idx, (name, pipe)) return pipe + def ensure_doc(self, doc): + return ( + doc + if isinstance(doc, PDFDoc) + else PDFDoc(content=doc) + if isinstance(doc, bytes) + else PDFDoc( + content=doc["content"], + id=doc.get("id") or doc.get(FILENAME), + ) + ) + def __call__(self, doc: Any) -> PDFDoc: """ Apply each component successively on a document. 
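The hunk above introduces `ensure_doc`, and the hunks below rewire `Pipeline.pipe` around `LazyCollection`. A minimal sketch of the intended usage — the bare pipeline instance, the placeholder PDF bytes and the ids below are hypothetical, not part of this patch:

```python
from edspdf import Pipeline
from edspdf.structures import PDFDoc

nlp = Pipeline()  # hypothetical bare pipeline; add_pipe(...) calls omitted

raw = b"%PDF-1.4 ..."  # placeholder bytes, not a real document

# ensure_doc normalizes the three accepted input types into a PDFDoc
doc_from_bytes = nlp.ensure_doc(raw)
doc_from_doc = nlp.ensure_doc(PDFDoc(content=raw))
doc_from_dict = nlp.ensure_doc({"content": raw, "id": "doc-1"})

# pipe() now returns a LazyCollection; processing options replace the
# deprecated `accelerator` argument and are set via set_processing
docs = nlp.pipe([raw, {"content": raw, "id": "doc-2"}]).set_processing(batch_size=8)
for doc in docs:  # iterating executes the lazy collection
    print(doc.id)
```

Per `LazyCollection.execute` below, passing `num_cpu_workers` or `num_gpu_workers` to `set_processing` selects the multiprocessing backend automatically, mirroring what the deprecated accelerator argument used to do.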
@@ -268,9 +283,9 @@ def pipe( inputs: Any, batch_size: Optional[int] = None, *, - accelerator: Optional[Union[str, Accelerator]] = None, - to_doc: Optional[ToDoc] = None, - from_doc: FromDoc = lambda doc: doc, + accelerator: Any = None, + to_doc: Any = None, + from_doc: Any = None, ) -> Iterable[PDFDoc]: """ Process a stream of documents by applying each component successively on @@ -302,37 +317,96 @@ def pipe( if batch_size is None: batch_size = self.batch_size - if accelerator is None: - accelerator = "simple" - if isinstance(accelerator, str): - accelerator = {"@accelerator": accelerator, "batch_size": batch_size} - if isinstance(accelerator, dict): - accelerator = Config(accelerator).resolve(registry=registry) - - kwargs = { - "inputs": inputs, - "model": self, - "to_doc": parse_obj_as(Optional[ToDoc], to_doc), - "from_doc": parse_obj_as(Optional[FromDoc], from_doc), - } - for k, v in list(kwargs.items()): - if v is None: - del kwargs[k] - - with self.train(False): - return accelerator(**kwargs) + lazy_collection = LazyCollection.ensure_lazy(inputs) + + if to_doc is not None: + warnings.warn( + "The `to_doc` argument is deprecated. " + "Please use the returned value's `map` method or the read/from_{} " + "method's converter argument instead.", + DeprecationWarning, + ) + if isinstance(to_doc, str): + to_doc = {"content_field": to_doc} + if isinstance(to_doc, dict): + to_doc_dict = to_doc + + def to_doc(doc): + return PDFDoc( + content=doc[to_doc_dict["content_field"]], + id=doc[to_doc_dict["id_field"]] + if "id_field" in to_doc_dict + else None, + ) + + if not callable(to_doc): + raise ValueError( + "The `to_doc` argument must be a callable or a dictionary", + ) + lazy_collection = lazy_collection.map(to_doc) + + lazy_collection = lazy_collection.map_pipeline(self).set_processing( + batch_size=batch_size + ) + + if accelerator is not None: + warnings.warn( + "The `accelerator` argument is deprecated. " + "Please use the returned value's `set_processing` method instead.", + DeprecationWarning, + ) + if isinstance(accelerator, str): + kwargs = {} + backend = accelerator + elif isinstance(accelerator, dict): + kwargs = dict(accelerator) + backend = accelerator.pop("@accelerator", "simple") + elif "Accelerator" in type(accelerator).__name__: + backend = ( + "multiprocessing" + if "MultiProcessing" in type(accelerator).__name__ + else "simple" + ) + kwargs = accelerator.__dict__ + lazy_collection.set_processing( + backend=backend, + **kwargs, + ) + if from_doc is not None: + warnings.warn( + "The `from_doc` argument is deprecated. 
" + "Please use the returned value's `map` method or the write/to_{} " + "method's converter argument instead.", + DeprecationWarning, + ) + if isinstance(from_doc, dict): + from_doc_dict = from_doc + + def from_doc(doc): + return {k: getattr(doc, v) for k, v in from_doc_dict.items()} + + if not callable(from_doc): + raise ValueError( + "The `from_doc` argument must be a callable or a dictionary", + ) + lazy_collection = lazy_collection.map(from_doc) + + return lazy_collection @contextmanager def cache(self): """ Enable caching for all (trainable) components in the pipeline """ - was_not_cached = self._cache is None - if was_not_cached: - self._cache = {} + to_disable = set() + for name, pipe in self.trainable_pipes(): + if getattr(pipe, "_current_cache_id", None) is None: + pipe.enable_cache() + to_disable.add(name) yield - if was_not_cached: - self._cache = None + for name, pipe in self.trainable_pipes(): + if name in to_disable: + pipe.disable_cache() def trainable_pipes( self, disable: Sequence[str] = () @@ -353,7 +427,7 @@ def trainable_pipes( if name not in disable and hasattr(pipe, "batch_process"): yield name, pipe - def post_init(self, gold_data: Iterable[PDFDoc], exclude: Optional[set] = None): + def post_init(self, gold_data: Iterable[PDFDoc], exclude: Optional[Set] = None): """ Completes the initialization of the pipeline by calling the post_init method of all components that have one. @@ -365,7 +439,7 @@ def post_init(self, gold_data: Iterable[PDFDoc], exclude: Optional[set] = None): gold_data: Iterable[PDFDoc] The documents to use for initialization. Each component will not necessarily see all the data. - exclude: Optional[set] + exclude: Optional[Set] The names of components to exclude from initialization. This argument will be gradually updated with the names of initialized components @@ -580,7 +654,7 @@ def collate( for name, component in self.pipeline: if name in batch: component_inputs = batch[name] - batch[name] = component.collate(component_inputs, device) + batch[name] = component.collate(component_inputs) return batch def parameters(self): @@ -600,7 +674,7 @@ def named_parameters(self): seen.add(param) yield f"{name}.{param_name}", param - def to(self, device: Optional["torch.device"] = None): # noqa F821 + def to(self, device: Union[str, Optional["torch.device"]] = None): # noqa F821 """Moves the pipeline to a given device""" for name, component in self.trainable_pipes(): component.to(device) @@ -630,7 +704,7 @@ def __exit__(ctx_self, type, value, traceback): return context() - def save( + def to_disk( self, path: Union[str, Path], *, exclude: Optional[Set[str]] = None ) -> None: """ @@ -648,30 +722,6 @@ def save( process. 
This list will be gradually filled in place as components are saved """ - - def save_tensors(path: Path): - import safetensors.torch - - shutil.rmtree(path, ignore_errors=True) - os.makedirs(path, exist_ok=True) - tensors = defaultdict(list) - tensor_to_group = defaultdict(list) - for pipe_name, pipe in self.trainable_pipes(disable=exclude): - for key, tensor in pipe.state_dict(keep_vars=True).items(): - full_key = join_path((pipe_name, *split_path(key))) - tensors[tensor].append(full_key) - tensor_to_group[tensor].append(pipe_name) - group_to_tensors = defaultdict(set) - for tensor, group in tensor_to_group.items(): - group_to_tensors["+".join(sorted(set(group)))].add(tensor) - for group, group_tensors in group_to_tensors.items(): - sub_path = path / f"{group}.safetensors" - tensor_dict = { - "+".join(tensors[p]): p - for p in {p.data_ptr(): p for p in group_tensors}.values() - } - safetensors.torch.save_file(tensor_dict, sub_path) - exclude = set() if exclude is None else exclude path = Path(path) if isinstance(path, str) else path @@ -690,18 +740,29 @@ def save_tensors(path: Path): if "config" not in exclude: self.config.to_disk(path / "config.cfg") - extra_exclude = set(exclude) - for pipe_name, pipe in self._components: - if hasattr(pipe, "save_extra_data") and pipe_name not in extra_exclude: - pipe.save_extra_data(path / pipe_name, exclude=extra_exclude) - if "tensors" not in exclude: - save_tensors(path / "tensors") + pwd = os.getcwd() + overrides = {"components": {}} + try: + os.chdir(path) + for pipe_name, pipe in self._components: + if hasattr(pipe, "to_disk") and pipe_name not in exclude: + pipe_overrides = pipe.to_disk(Path(pipe_name), exclude=exclude) + overrides["components"][pipe_name] = pipe_overrides + finally: + os.chdir(pwd) + + config = self.config.merge(overrides) - def load_state_from_disk( + if "config" not in exclude: + config.to_disk(path / "config.cfg") + + save = to_disk + + def from_disk( self, path: Union[str, Path], *, - exclude: Set[str] = None, + exclude: Optional[Union[str, Sequence[str]]] = None, device: Optional[Union[str, "torch.device"]] = "cpu", # noqa F821 ) -> "Pipeline": """ @@ -711,59 +772,43 @@ def load_state_from_disk( ---------- path: Union[str, Path] The path to the directory to load the pipeline from - exclude: Set[str] + exclude: Optional[Union[str, Sequence[str]]] The names of the components, or attributes to exclude from the loading - process. This list will be gradually filled in place as components are - loaded + process. 
+ device: Optional[Union[str, "torch.device"]] + Device to use when loading the tensors """ def deserialize_meta(path: Path) -> None: if path.exists(): - data = json.loads(path.read_text()) + with open(path, "r") as f: + data = json.load(f) self.meta.update(data) + # self.meta always overrides meta["vectors"] with the metadata + # from self.vocab.vectors, so set the name directly + + exclude = ( + set() + if exclude is None + else {exclude} + if isinstance(exclude, str) + else set(exclude) + ) - def deserialize_tensors(path: Path): - import safetensors.torch - - trainable_components = dict(self.trainable_pipes()) - for file_name in path.iterdir(): - pipe_names = file_name.stem.split("+") - if any(pipe_name in trainable_components for pipe_name in pipe_names): - # We only load tensors in one of the pipes since parameters - # are expected to be shared - pipe = trainable_components[pipe_names[0]] - tensor_dict = {} - for keys, tensor in safetensors.torch.load_file( - file_name, device=device - ).items(): - split_keys = [split_path(key) for key in keys.split("+")] - key = next(key for key in split_keys if key[0] == pipe_names[0]) - tensor_dict[join_path(key[1:])] = tensor - # Non-strict because tensors of a given pipeline can be shared - # between multiple files - print(f"Loading tensors of {pipe_names[0]} from {file_name}") - extra_tensors = set(tensor_dict) - set( - pipe.state_dict(keep_vars=True).keys() - ) - if extra_tensors: - warnings.warn( - f"{file_name} contains tensors that are not in the state" - f"dict of {pipe_names[0]}: {sorted(extra_tensors)}" - ) - pipe.load_state_dict(tensor_dict, strict=False) - - exclude = set() if exclude is None else exclude - + path = (Path(path) if isinstance(path, str) else path).absolute() if "meta" not in exclude: deserialize_meta(path / "meta.json") - extra_exclude = set(exclude) - for name, proc in self._components: - if hasattr(proc, "load_extra_data") and name not in extra_exclude: - proc.load_extra_data(path / name, extra_exclude) - - if "tensors" not in exclude: - deserialize_tensors(path / "tensors") + pwd = os.getcwd() + try: + os.chdir(path) + for name, proc in self._components: + if hasattr(proc, "from_disk") and name not in exclude: + proc.from_disk(Path(name), exclude=exclude) + # Convert to list here in case exclude is (default) tuple + exclude.add(name) + finally: + os.chdir(pwd) self._path = path # type: ignore[assignment] return self @@ -779,7 +824,7 @@ def load( path = Path(path) if isinstance(path, str) else path config = Config.from_disk(path / "config.cfg") self = Pipeline.from_config(config) - self.load_state_from_disk(path, exclude=exclude, device=device) + self.load(path, exclude=exclude, device=device) return self # override config property getter to remove "factory" key from components @@ -830,12 +875,10 @@ def __exit__(ctx_self, type, value, traceback): if enable is None and disable is None: raise ValueError("Expected either `enable` or `disable`") - if isinstance(disable, str): - disable = [disable] + disable = [disable] if isinstance(disable, str) else disable pipe_names = set(self.pipe_names) if enable is not None: - if isinstance(enable, str): - enable = [enable] + enable = [enable] if isinstance(enable, str) else enable if set(enable) - pipe_names: raise ValueError( "Enabled pipes {} not found in pipeline.".format( @@ -863,22 +906,26 @@ def package( self, name: Optional[str] = None, root_dir: Union[str, Path] = ".", + build_dir: Union[str, Path] = "build", + dist_dir: Union[str, Path] = "dist", artifacts_name: str = 
"artifacts", check_dependencies: bool = False, project_type: Optional[Literal["poetry", "setuptools"]] = None, - version: str = "0.1.0", + version: Optional[str] = None, metadata: Optional[Dict[str, Any]] = {}, distributions: Optional[Sequence[Literal["wheel", "sdist"]]] = ["wheel"], config_settings: Optional[Mapping[str, Union[str, Sequence[str]]]] = None, isolation: bool = True, skip_build_dependency_check: bool = False, ): - from .utils.package import package + from edspdf.utils.package import package return package( pipeline=self, name=name, root_dir=root_dir, + build_dir=build_dir, + dist_dir=dist_dir, artifacts_name=artifacts_name, check_dependencies=check_dependencies, project_type=project_type, @@ -892,19 +939,99 @@ def package( def load( - config: Union[Path, str, Config], - device: Optional[Union[str, "torch.device"]] = "cpu", # noqa F821 -) -> Pipeline: - error = "The load function expects a Config or a path to a config file" - if isinstance(config, (Path, str)): - path = Path(config) - if path.is_dir(): - return Pipeline.load(path, device=device) - elif path.is_file(): - config = Config.from_disk(path) - else: - raise ValueError(error) - elif not isinstance(config, Config): + model: Union[Path, str, Config], + overrides: Optional[Dict[str, Any]] = None, + *, + exclude: Optional[Union[str, Iterable[str]]] = None, + device: Optional[Union[str, "torch.device"]] = "cpu", +): + """ + Load a pipeline from a config file or a directory. + + Examples + -------- + + ```{ .python .no-check } + import edspdf + + nlp = edspdf.load( + "path/to/config.cfg", + overrides={"components": {"my_component": {"arg": "value"}}}, + ) + ``` + + Parameters + ---------- + model: Union[Path, str, Config] + The config to use for the pipeline, or the path to a config file or a directory. + overrides: Optional[Dict[str, Any]] + Overrides to apply to the config when loading the pipeline. These are the + same parameters as the ones used when initializing the pipeline. + exclude: Optional[Union[str, Iterable[str]]] + The names of the components, or attributes to exclude from the loading + process. :warning: The `exclude` argument will be mutated in place. + device: Optional[Union[str, "torch.device"]] + Device to use when loading the tensors + + Returns + ------- + Pipeline + """ + error = ( + "The load function expects either :\n" + "- a confit Config object\n" + "- the path of a config file (.cfg file)\n" + "- the path of a trained model\n" + "- the name of an installed pipeline package\n" + f"but got {model!r} which is neither" + ) + if isinstance(model, (Path, str)): + path = Path(model) + is_dir = path.is_dir() + is_config = path.is_file() and path.suffix == ".cfg" + try: + module = importlib.import_module(model) + is_package = True + except (ImportError, AttributeError, TypeError): + module = None + is_package = False + if is_dir and is_package: + warnings.warn( + "The path provided is both a directory and a package : edspdf will " + "load the package. To load from the directory instead, please pass the " + f'path as "./{path}" instead.' 
+ ) + if is_dir: + path = (Path(path) if isinstance(path, str) else path).absolute() + config = Config.from_disk(path / "config.cfg") + if overrides: + config = config.merge(overrides) + pwd = os.getcwd() + try: + os.chdir(path) + nlp = Pipeline.from_config(config) + nlp.from_disk(path, exclude=exclude, device=device) + finally: + os.chdir(pwd) + return nlp + elif is_config: + model = Config.from_disk(path) + elif is_package: + # Load as package + available_kwargs = { + "overrides": overrides, + "exclude": exclude, + "device": device, + } + signature_kwargs = inspect.signature(module.load).parameters + kwargs = { + name: available_kwargs[name] + for name in signature_kwargs + if name in available_kwargs + } + return module.load(**kwargs) + + if not isinstance(model, Config): raise ValueError(error) - return Pipeline.from_config(config) + return Pipeline.from_config(model) diff --git a/edspdf/pipes/classifiers/trainable.py b/edspdf/pipes/classifiers/trainable.py index f879af48..e7d32648 100644 --- a/edspdf/pipes/classifiers/trainable.py +++ b/edspdf/pipes/classifiers/trainable.py @@ -155,9 +155,9 @@ def preprocess_supervised(self, doc: PDFDoc) -> Dict[str, Any]: ], } - def collate(self, batch, device: torch.device) -> Dict: + def collate(self, batch) -> Dict: collated = { - "embedding": self.embedding.collate(batch["embedding"], device), + "embedding": self.embedding.collate(batch["embedding"]), "doc_id": batch["doc_id"], } if "labels" in batch: @@ -167,7 +167,6 @@ def collate(self, batch, device: torch.device) -> Dict: batch["labels"], data_dims=("line",), full_names=("sample", "page", "line"), - device=device, dtype=torch.long, ), } @@ -208,28 +207,28 @@ def postprocess(self, docs: Sequence[PDFDoc], output: Dict) -> Sequence[PDFDoc]: b.label = self.label_voc.decode(label) if b.text != "" else None return docs - def save_extra_data(self, path: Path, exclude: Set): + def to_disk(self, path: Path, exclude: Set): if self.name in exclude: return exclude.add(self.name) - self.embedding.save_extra_data(path / "embedding", exclude) - os.makedirs(path, exist_ok=True) with (path / "label_voc.json").open("w") as f: json.dump(self.label_voc.indices, f) - def load_extra_data(self, path: Path, exclude: Set): + return super().to_disk(path, exclude) + + def from_disk(self, path: Path, exclude: Set): if self.name in exclude: return exclude.add(self.name) - self.embedding.load_extra_data(path / "embedding", exclude) - label_voc_indices = dict(self.label_voc.indices) with (path / "label_voc.json").open("r") as f: self.label_voc.indices = json.load(f) self.update_weights_from_vocab_(label_voc_indices) + + super().from_disk(path, exclude) diff --git a/edspdf/pipes/embeddings/box_layout_embedding.py b/edspdf/pipes/embeddings/box_layout_embedding.py index 0bd8c29f..e27661a4 100644 --- a/edspdf/pipes/embeddings/box_layout_embedding.py +++ b/edspdf/pipes/embeddings/box_layout_embedding.py @@ -9,7 +9,7 @@ BoxLayoutPreprocessor, ) from edspdf.registry import registry -from edspdf.trainable_pipe import NestedSequences, TrainablePipe +from edspdf.trainable_pipe import TrainablePipe @registry.factory.register("box-layout-embedding") @@ -71,8 +71,8 @@ def __init__( def preprocess(self, doc): return self.box_preprocessor.preprocess(doc) - def collate(self, batch: NestedSequences, device: torch.device) -> BoxLayoutBatch: - return self.box_preprocessor.collate(batch, device) + def collate(self, batch) -> BoxLayoutBatch: + return self.box_preprocessor.collate(batch) @classmethod def _make_embed(cls, n_positions, size, 
mode): diff --git a/edspdf/pipes/embeddings/box_layout_preprocessor.py b/edspdf/pipes/embeddings/box_layout_preprocessor.py index ae9715f3..10559f6f 100644 --- a/edspdf/pipes/embeddings/box_layout_preprocessor.py +++ b/edspdf/pipes/embeddings/box_layout_preprocessor.py @@ -74,11 +74,10 @@ def preprocess(self, doc: PDFDoc, supervision: bool = False): "last_page": [[b.page_num == last_p for b in p.text_boxes] for p in pages], } - def collate(self, batch, device: torch.device) -> BoxLayoutBatch: + def collate(self, batch) -> BoxLayoutBatch: kw = { "full_names": ["sample", "page", "line"], "data_dims": ["line"], - "device": device, } return { diff --git a/edspdf/pipes/embeddings/huggingface_embedding.py b/edspdf/pipes/embeddings/huggingface_embedding.py index 16ccf076..d1d34cf4 100644 --- a/edspdf/pipes/embeddings/huggingface_embedding.py +++ b/edspdf/pipes/embeddings/huggingface_embedding.py @@ -181,7 +181,7 @@ def preprocess(self, doc: PDFDoc): return res - def collate(self, batch, device): + def collate(self, batch): # Flatten most of these arrays to process batches page per page and # not sample per sample @@ -274,29 +274,24 @@ def collate(self, batch, device): kw = dict( full_names=("sample", "page", "subword"), data_dims=("subword",), - device=device, ) collated = { "input_ids": as_folded_tensor(batch["input_ids"], **kw, dtype=torch.long), "bbox": as_folded_tensor(batch["bbox"], **kw, dtype=torch.long), - "windows": windows.to(device), - "indexer": indexer[line_window_indices].to(device), - "line_window_indices": indexer[line_window_indices].as_tensor().to(device), - "line_window_offsets_flat": line_window_offsets_flat.to(device), + "windows": windows, + "indexer": indexer[line_window_indices], + "line_window_indices": indexer[line_window_indices].as_tensor(), + "line_window_offsets_flat": line_window_offsets_flat, } if self.use_image: - collated["pixel_values"] = ( - torch.stack( - [ - torch.from_numpy(page_pixels) - for sample_pages in batch["pixel_values"] - for page_pixels in sample_pages - ], - dim=0, - ) - .repeat_interleave(torch.as_tensor(windows_count_per_page), dim=0) - .to(device) - ) + collated["pixel_values"] = torch.stack( + [ + torch.from_numpy(page_pixels) + for sample_pages in batch["pixel_values"] + for page_pixels in sample_pages + ], + dim=0, + ).repeat_interleave(torch.as_tensor(windows_count_per_page), dim=0) return collated def forward(self, batch): diff --git a/edspdf/pipes/embeddings/simple_text_embedding.py b/edspdf/pipes/embeddings/simple_text_embedding.py index cc2d8691..f7674102 100644 --- a/edspdf/pipes/embeddings/simple_text_embedding.py +++ b/edspdf/pipes/embeddings/simple_text_embedding.py @@ -154,7 +154,7 @@ def post_init(self, gold_data, exclude: set): self.update_weights_from_vocab_(vocab_items_before) - def save_extra_data(self, path: Path, exclude: Set): + def to_disk(self, path: Path, exclude: Set): if self.name in exclude: return @@ -169,7 +169,9 @@ def save_extra_data(self, path: Path, exclude: Set): with (path / "norm_voc.json").open("w") as f: json.dump(self.norm_voc.indices, f) - def load_extra_data(self, path: Path, exclude: Set): + return super().to_disk(path, exclude) + + def from_disk(self, path: Path, exclude: Set): if self.name in exclude: return @@ -191,6 +193,8 @@ def load_extra_data(self, path: Path, exclude: Set): self.update_weights_from_vocab_(vocab_items_before) + super().from_disk(path, exclude) + def preprocess(self, doc: PDFDoc): tokens_shape = [] tokens_prefix = [] @@ -228,10 +232,9 @@ def preprocess(self, doc: PDFDoc): 
"tokens_norm": tokens_norm, } - def collate(self, batch, device: torch.device) -> BoxTextEmbeddingInputBatch: + def collate(self, batch) -> BoxTextEmbeddingInputBatch: kwargs = dict( dtype=torch.long, - device=device, data_dims=("word",), full_names=( "sample", diff --git a/edspdf/processing/__init__.py b/edspdf/processing/__init__.py new file mode 100644 index 00000000..65a80d6e --- /dev/null +++ b/edspdf/processing/__init__.py @@ -0,0 +1,9 @@ +from typing import TYPE_CHECKING + +from edspdf.utils.lazy_module import lazify + +lazify() + +if TYPE_CHECKING: + from .simple import execute_simple_backend + from .multiprocessing import execute_multiprocessing_backend diff --git a/edspdf/processing/multiprocessing.py b/edspdf/processing/multiprocessing.py new file mode 100644 index 00000000..841c0df7 --- /dev/null +++ b/edspdf/processing/multiprocessing.py @@ -0,0 +1,893 @@ +from __future__ import annotations + +import copyreg +import gc +import io +import logging +import multiprocessing +import multiprocessing.reduction +import os +import sys +import tempfile +import warnings +from contextlib import nullcontext +from multiprocessing.connection import wait +from random import shuffle +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Optional, + Tuple, + TypeVar, + Union, +) + +import dill +from typing_extensions import TypedDict + +from edspdf import PDFDoc +from edspdf.lazy_collection import LazyCollection +from edspdf.utils.collections import batchify, flatten + +if TYPE_CHECKING: + import torch + + from edspdf.trainable_pipe import TrainablePipe + +Stage = TypedDict( + "Stage", + { + "cpu_components": List[Tuple[Callable, Dict]], + "gpu_component": Optional[Any], + }, +) + + +class ForkingPickler(dill.Pickler): + """ + ForkingPickler that uses dill instead of pickle to transfer objects between + processes. + """ + + _extra_reducers = {} + _copyreg_dispatch_table = copyreg.dispatch_table + + def __new__(cls, *args, **kwargs): + result = dill.Pickler.__new__(ForkingPickler) + # Python would not call __init__ if the original + # multiprocessing.reduction.ForkingPickler called, leading to a call to this + # monkey-patched __new__ method, because [original cls] != [type of result] + # (see https://docs.python.org/3/reference/datamodel.html#basic-customization) + # so we force the call to __init__ here + if not isinstance(result, cls): + result.__init__(*args, **kwargs) + return result + + def __init__(self, *args, **kwds): + super().__init__(*args, **kwds) + self.dispatch_table = self._copyreg_dispatch_table.copy() + self.dispatch_table.update(self._extra_reducers) + + @classmethod + def register(cls, type, reduce): + """Register a reduce function for a type.""" + cls._extra_reducers[type] = reduce + + @classmethod + def dumps(cls, obj, protocol=None, *args, **kwds): + buf = io.BytesIO() + cls(buf, protocol, *args, **kwds).dump(obj) + return buf.getbuffer() + + loads = dill.loads + + +def replace_pickler(): + """ + Replace the default pickler used by multiprocessing with dill. + "multiprocess" didn't work for obscure reasons (maybe the reducers / dispatchers + are not propagated between multiprocessing and multiprocess => torch specific + reducers might be missing ?), so this patches multiprocessing directly. + directly. + + For some reason I do not explain, this has a massive impact on the performance of + the multiprocessing backend. With the original pickler, the performance can be + up to 2x slower than with our custom one. 
+ """ + old_pickler = multiprocessing.reduction.ForkingPickler + + before = ( + dict(ForkingPickler._extra_reducers), + old_pickler.__new__, + old_pickler.dumps, + old_pickler.loads, + old_pickler.register, + ) + + old_pickler.__new__ = ForkingPickler.__new__ + old_pickler.dumps = ForkingPickler.dumps + old_pickler.loads = ForkingPickler.loads + old_pickler.register = ForkingPickler.register + ForkingPickler._extra_reducers.update( + multiprocessing.reduction.ForkingPickler._extra_reducers + ) + + def revert(): + ( + ForkingPickler._extra_reducers, + old_pickler.__new__, + old_pickler.dumps, + old_pickler.loads, + old_pickler.register, + ) = before + + return revert + + +# Should we check if the multiprocessing module of edspdf +# is responsible for this child process before replacing the pickler ? +if ( + multiprocessing.current_process() != "MainProcess" + or hasattr(multiprocessing, "parent_process") + and multiprocessing.parent_process() is not None +): + replace_pickler() + +DEBUG = True + +debug = ( + (lambda *args, flush=False, **kwargs: print(*args, **kwargs, flush=True)) + if DEBUG + else lambda *args, **kwargs: None +) + +try: + import torch + + # Torch may still be imported as a namespace package, so we can access the + # torch.save and torch.load functions + torch_save = torch.save + torch_load = torch.load + + MAP_LOCATION = None + + try: + from accelerate.hooks import AlignDevicesHook + + # We need to replace the "execution_device" attribute of the AlignDevicesHook + # using map_location when unpickling the lazy collection + + def save_align_devices_hook(pickler: Any, obj: Any): + pickler.save_reduce(load_align_devices_hook, (obj.__dict__,), obj=obj) + + def load_align_devices_hook(state): + state["execution_device"] = MAP_LOCATION + new_obj = AlignDevicesHook.__new__(AlignDevicesHook) + new_obj.__dict__.update(state) + return new_obj + + except ImportError: + AlignDevicesHook = None + + def dump(*args, **kwargs): + # We need to replace the "execution_device" attribute of the AlignDevicesHook + # using map_location when pickling the lazy collection + old = None + try: + if AlignDevicesHook is not None: + old = dill.Pickler.dispatch.get(AlignDevicesHook) + dill.Pickler.dispatch[AlignDevicesHook] = save_align_devices_hook + dill.settings["recurse"] = True + return torch_save(*args, pickle_module=dill, **kwargs) + finally: + dill.settings["recurse"] = False + if AlignDevicesHook is not None: + del dill.Pickler.dispatch[AlignDevicesHook] + if old is not None: # pragma: no cover + dill.Pickler.dispatch[AlignDevicesHook] = old + + def load(*args, map_location=None, **kwargs): + global MAP_LOCATION + MAP_LOCATION = map_location + if torch.__version__ >= "2.1" and isinstance(args[0], str): + kwargs["mmap"] = True + result = torch_load( + *args, + pickle_module=dill, + map_location=map_location, + **kwargs, + ) + MAP_LOCATION = None + return result + +except (ImportError, AttributeError): # pragma: no cover + + def load(file, *args, map_location=None, **kwargs): + # check if path + if isinstance(file, str): + with open(file, "rb") as f: + return dill.load(f, *args, **kwargs) + return dill.load(file, *args, **kwargs) + + dump = dill.dump + +T = TypeVar("T") + + +class Exchanger: + def __init__( + self, + mp: multiprocessing.context.BaseContext, + num_stages: int, + num_gpu_workers: int, + num_cpu_workers: int, + gpu_worker_devices: List[Any], + ): + # queue for cpu input tasks + self.gpu_worker_devices = gpu_worker_devices + self.num_cpu_workers = num_cpu_workers + self.num_gpu_workers 
= num_gpu_workers + # We add prioritized queue at the end for STOP signals + self.cpu_inputs_queues = [ + [mp.Queue()] + [mp.SimpleQueue() for _ in range(num_stages + 1)] + # The input queue is not shared between processes, since calling `wait` + # on a queue reader from multiple processes may lead to a deadlock + for _ in range(num_cpu_workers) + ] + self.gpu_inputs_queues = [ + [mp.Queue() for _ in range(num_stages + 1)] for _ in range(num_gpu_workers) + ] + self.outputs_queue = mp.Queue() + self.num_stages = num_stages + + # noinspection PyUnresolvedReferences + def get_cpu_task(self, idx, allow_new_task: bool = True): + this_cpu_input_queues = self.cpu_inputs_queues[idx] + if not allow_new_task: + this_cpu_input_queues = this_cpu_input_queues[1:] + queue_readers = wait([queue._reader for queue in this_cpu_input_queues]) + stage, queue = next( + (stage, q) + for stage, q in reversed(list(enumerate(self.cpu_inputs_queues[idx]))) + if q._reader in queue_readers + ) + item = queue.get() + return stage, item + + def put_cpu(self, item, stage, idx): + return self.cpu_inputs_queues[idx][stage].put(item) + + # noinspection PyUnresolvedReferences + def get_gpu_task(self, idx): + queue_readers = wait([queue._reader for queue in self.gpu_inputs_queues[idx]]) + stage, queue = next( + (stage, q) + for stage, q in reversed(list(enumerate(self.gpu_inputs_queues[idx]))) + if q._reader in queue_readers + ) + item = queue.get() + return stage, item + + def put_gpu(self, item, stage, idx): + return self.gpu_inputs_queues[idx][stage].put(item) + + def put_results(self, items): + self.outputs_queue.put(items) + + def iter_results(self): + for out in iter(self.outputs_queue.get, None): + yield out + + +class CPUWorker: + def __init__( + self, + cpu_idx: int, + exchanger: Exchanger, + gpu_pipe_names: List[str], + lazy_collection_path: str, + device: Union[str, "torch.device"], + ): + super(CPUWorker, self).__init__() + + self.cpu_idx = cpu_idx + self.exchanger = exchanger + self.gpu_pipe_names = gpu_pipe_names + self.lazy_collection_path = lazy_collection_path + self.device = device + + def _run(self): + # Cannot pass torch tensor during init i think ? 
otherwise i get + # ValueError: bad value(s) in fds_to_keep + # mp._prctl_pr_set_pdeathsig(signal.SIGINT) + + num_cpu = self.exchanger.num_cpu_workers + + had_error = False + expect_new_tasks = True + + def read_tasks(): + nonlocal next_batch_id, expect_new_tasks, had_error + + if lc.batch_unit == "lines": + + def formula(batch): + return sum(len(doc.content_boxes) for doc in batch) + + elif lc.batch_unit == "pages": + + def formula(batch): + return sum(len(doc.pages) for doc in batch) + + else: + formula = len + + while expect_new_tasks or len(active_batches) > 0: + stage, task = self.exchanger.get_cpu_task( + idx=self.cpu_idx, + allow_new_task=len(active_batches) < 2, + ) + # stage, task = next(iterator) + # Prioritized STOP signal: something bad happened in another process + # -> stop listening to input queues and raise StopIteration (return) + if task is None and stage == self.exchanger.num_stages + 1: + return + # Non prioritized STOP signal: there are no more tasks to process and + # we should smoothly stop (wait that there are no more active tasks, + # and finalize the writer) + if stage == 0 and task is None: + expect_new_tasks = False + continue + + # If first stage, we receive tasks that may require batching + # again => we split them into chunks + if stage == 0: + task_id, fragments = task + chunks = list( + batchify(lc.reader.read_worker(fragments), lc.chunk_size) + ) + for chunk_idx, docs in enumerate(chunks): + for pipe, kwargs, tokenizer in stages[0]["cpu_components"]: + if hasattr(pipe, "batch_process"): + docs = pipe.batch_process(docs) + else: + docs: List[PDFDoc] = [ + pipe(doc, **kwargs) for doc in docs + ] + + if lc.sort_by_size: + docs.sort(key=lambda doc: len(doc.content_boxes)) + + batches = [ + batch + for batch in batchify( + docs, + batch_size=lc.batch_size, + formula=formula, + ) + ] + + for batch_idx, batch in enumerate(batches): + assert len(batch) > 0 + batch_id = next_batch_id + + # We mark the task id only for the last batch of a task + # since the purpose of storing the task id is to know + # when the worker has finished processing the task, which + # is true only when the last batch has been processed + active_batches[batch_id] = ( + batch, + task_id + if (batch_idx == len(batches) - 1) + and (chunk_idx == len(chunks) - 1) + else None, + ) + next_batch_id += num_cpu + # gpu_idx = None + # batch_id = we have just created a new batch + # result from the last stage = None + yield stage, (None, batch_id, None) + else: + yield stage, task + + lc: LazyCollection = load(self.lazy_collection_path, map_location=self.device) + # for name, pipe, *rest in lc.pipeline: + # move_to_device(pipe, self.device) + + stages: List[Stage] = [{"cpu_components": [], "gpu_component": None}] + for name, pipe, *rest in lc.pipeline: + if name in self.gpu_pipe_names: + stages[-1]["gpu_component"] = pipe + stages.append({"cpu_components": [], "gpu_component": None}) + else: + stages[-1]["cpu_components"].append((pipe, *rest)) + + # Start at cpu_idx to avoid having all workers sending their + # first batch (0 % num_device, cf below) to the same gpu + next_batch_id = self.cpu_idx + active_batches = {} + + logging.info(f"Starting cpu {self.cpu_idx}, PID {os.getpid()}") + self.exchanger.outputs_queue.put(None) + for stage, (gpu_idx, batch_id, result) in read_tasks(): + if had_error: + continue # pragma: no cover + try: + docs, task_id = active_batches.pop(batch_id) + for name, pipe, *rest in lc.pipeline: + if hasattr(pipe, "enable_cache"): + pipe.enable_cache(batch_id) + if stage > 0: + 
gpu_pipe = stages[stage - 1]["gpu_component"] + docs = gpu_pipe.postprocess(docs, result) # type: ignore + + for pipe, kwargs, tokenizer in stages[stage]["cpu_components"]: + if hasattr(pipe, "batch_process"): + docs = pipe.batch_process(docs) + else: + docs = [pipe(doc, **kwargs) for doc in docs] + + # if stage == 0: + # print("BATCH ID", batch_id, [len(d) for d in docs]) + + gpu_pipe: "TrainablePipe" = stages[stage]["gpu_component"] + if gpu_pipe is not None: + preprocessed = gpu_pipe.make_batch(docs) # type: ignore + active_batches[batch_id] = (docs, task_id) + if gpu_idx is None: + gpu_idx = batch_id % len(self.exchanger.gpu_worker_devices) + collated = gpu_pipe.collate(preprocessed) + collated = gpu_pipe.batch_to_device( + collated, + device=self.exchanger.gpu_worker_devices[gpu_idx], + ) + self.exchanger.put_gpu( + item=(self.cpu_idx, batch_id, collated), + idx=gpu_idx, + stage=stage, + ) + # batch_id += num_cpu # check if not needed ? + else: + for name, pipe, *rest in lc.pipeline: + if hasattr(pipe, "disable_cache"): + pipe.disable_cache(batch_id) + results, count = ( + lc.writer.write_worker(docs) + if lc.writer is not None + else (docs, len(docs)) + ) + self.exchanger.put_results( + ( + results, + count, + self.cpu_idx, + task_id, + ) + ) + except BaseException as e: + had_error = True + import traceback + + print(traceback.format_exc(), flush=True) + self.exchanger.put_results((e, 0, self.cpu_idx, None)) + + if not had_error: + if lc.writer is not None: + results, count = lc.writer.finalize() + if count > 0: + self.exchanger.put_results((results, count, self.cpu_idx, None)) + # We need to drain the queues of GPUWorker fed inputs (pre-moved to GPU) + # to ensure no tensor allocated on producer processes (CPUWorker via + # collate) are left in consumer processes + task = True # anything but None + while task is not None: + stage, task = self.exchanger.get_cpu_task(self.cpu_idx) + + def run(self): + self._run() + gc.collect() + try: + sys.modules["torch"].cuda.empty_cache() + except (AttributeError, KeyError): # pragma: no cover + pass + + +class GPUWorker: + def __init__( + self, + gpu_idx, + exchanger: Exchanger, + gpu_pipe_names: List[str], + lazy_collection_path: str, + device: Union[str, "torch.device"], + ): + super().__init__() + + self.device = device + self.gpu_idx = gpu_idx + self.exchanger = exchanger + + self.gpu_pipe_names = gpu_pipe_names + self.lazy_collection_path = lazy_collection_path + + def _run(self): + import torch + + # mp._prctl_pr_set_pdeathsig(signal.SIGINT) + had_error = False + + lc = load(self.lazy_collection_path, map_location=self.device) + stage_components = [ + pipe + # move_to_device(pipe, self.device) + for name, pipe, *_ in lc.pipeline + if name in self.gpu_pipe_names + ] + + del lc + logging.info(f"Starting gpu {self.gpu_idx}") + self.exchanger.outputs_queue.put(None) + with torch.no_grad(): + while True: + stage, task = self.exchanger.get_gpu_task(self.gpu_idx) + if task is None: + break + if had_error: + continue # pragma: no cover + + try: + cpu_idx, batch_id, batch = task + pipe = stage_components[stage] + pipe.enable_cache(batch_id) + res = pipe.module_forward(batch) + self.exchanger.put_cpu( + item=(self.gpu_idx, batch_id, res), + stage=stage + 1, + idx=cpu_idx, + ) + if stage == len(stage_components) - 1: + pipe.disable_cache(batch_id) + del batch, task + except torch.cuda.OutOfMemoryError as e: + import traceback + + # print(debug_sizes, flush=True) + print(traceback.format_exc(), flush=True) + + # import gc + # gc.collect() + # 
torch.cuda.empty_cache() + # + # cpu_idx, batch_id, batch = task + # component = stage_components[stage] + # res = component.module_forward(batch) + # self.exchanger.put_cpu( + # item=(self.gpu_idx, batch_id, res), + # stage=stage + 1, + # idx=cpu_idx, + # ) + # del batch, task + self.exchanger.put_results((e, 0, None, None)) + except BaseException as e: + had_error = True + import traceback + + self.exchanger.put_results((e, 0, None, None)) + task = batch = res = None # noqa + # We need to drain the queues of CPUWorker fed inputs (pre-moved to GPU) + # to ensure no tensor allocated on producer processes (CPUWorker via + # collate) are left in consumer processes + task = True # anything but None + while task is not None: + stage, task = self.exchanger.get_gpu_task(self.gpu_idx) + + def run(self): + self._run() + gc.collect() + sys.modules["torch"].cuda.empty_cache() + + +DEFAULT_MAX_CPU_WORKERS = 4 + + +def execute_multiprocessing_backend( + lc: LazyCollection, +): + """ + If you have multiple CPU cores, and optionally multiple GPUs, we provide the + `multiprocessing` backend that allows to run the inference on multiple + processes. + + This accelerator dispatches the batches between multiple workers + (data-parallelism), and distribute the computation of a given batch on one or two + workers (model-parallelism): + + - a `CPUWorker` which handles the non deep-learning components and the + preprocessing, collating and postprocessing of deep-learning components + - a `GPUWorker` which handles the forward call of the deep-learning components + + If no GPU is available, no `GPUWorker` is started, and the `CPUWorkers` handle + the forward call of the deep-learning components as well. + + The advantage of dedicating a worker to the deep-learning components is that it + allows to prepare multiple batches in parallel in multiple `CPUWorker`, and ensure + that the `GPUWorker` never wait for a batch to be ready. + + The overall architecture described in the following figure, for 3 CPU workers and 2 + GPU workers. + +
+
+    [figure: overall architecture of the multiprocessing backend]
+
+ + Here is how a small pipeline with rule-based components and deep-learning components + is distributed between the workers: + +
+
+    [figure: how a small pipeline is distributed between the CPU and GPU workers]
+
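    A rough usage sketch, adapted from the multiprocessing test added further down
    in this patch (the model path and PDF paths are placeholders; real GPU use would
    pass CUDA devices instead of the "cpu" stub):

        from pathlib import Path

        import edspdf
        from edspdf import PDFDoc

        model = edspdf.load("path/to/model")  # hypothetical trained pipeline
        pdf_blobs = [Path(p).read_bytes() for p in ["a.pdf", "b.pdf"]]  # placeholders

        docs = edspdf.data.from_iterable(
            [{"content": blob} for blob in pdf_blobs],
            converter=lambda x: PDFDoc(content=x["content"]),
        )
        docs = docs.map_pipeline(model)
        docs = docs.set_processing(
            batch_size=2,
            num_cpu_workers=1,
            num_gpu_workers=1,
            gpu_worker_devices=["cpu"],
        )
        results = list(
            docs.to_iterable(converter=lambda d: {"text": d.aggregated_texts})
        )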
+ + """ + try: + TrainablePipe = sys.modules["edspdf.trainable_pipe"].TrainablePipe + except (KeyError, AttributeError): # pragma: no cover + TrainablePipe = None + + steps = lc.pipeline + num_cpu_workers = lc.num_cpu_workers + num_gpu_workers = lc.num_gpu_workers + show_progress = lc.show_progress + process_start_method = lc.process_start_method + + # Infer which pipes should be accelerated on GPU + gpu_steps_candidates = ( + [name for name, component, *_ in steps if isinstance(component, TrainablePipe)] + if TrainablePipe is not None + else [] + ) + gpu_pipe_names = ( + gpu_steps_candidates if lc.gpu_pipe_names is None else lc.gpu_pipe_names + ) + if set(gpu_pipe_names) - set(gpu_steps_candidates): + raise ValueError( + "GPU accelerated pipes {} could not be found in the model".format( + sorted(set(gpu_pipe_names) - set(gpu_steps_candidates)) + ) + ) + + requires_gpu = ( + num_gpu_workers is None + and len(gpu_pipe_names) + or num_gpu_workers is not None + and num_gpu_workers > 0 + ) + + num_devices = 0 + if requires_gpu: + import torch + + num_devices = torch.cuda.device_count() + logging.info(f"Number of available devices: {num_devices}") + + if num_gpu_workers is None: + num_gpu_workers = num_devices + + if num_gpu_workers and process_start_method == "fork": + warnings.warn( + "Using fork start method with GPU workers may lead to deadlocks. " + "Consider using process_start_method='spawn' instead." + ) + + process_start_method = process_start_method or "spawn" + else: + num_gpu_workers = 0 + + default_method = multiprocessing.get_start_method() + if process_start_method is not None and default_method != process_start_method: + logging.info(f"Switching process start method to {process_start_method}") + + mp = multiprocessing.get_context(process_start_method) + max_workers = max(min(mp.cpu_count() - num_gpu_workers, DEFAULT_MAX_CPU_WORKERS), 0) + num_cpu_workers = ( + (num_gpu_workers or max_workers) + if num_cpu_workers is None + else max_workers + num_cpu_workers + 1 + if num_cpu_workers < 0 + else num_cpu_workers + ) + + if num_gpu_workers == 0: + gpu_pipe_names = [] + + gpu_worker_devices = ( + [ + f"cuda:{gpu_idx * num_devices // num_gpu_workers}" + for gpu_idx in range(num_gpu_workers) + ] + if requires_gpu and lc.gpu_worker_devices is None + else [] + if lc.gpu_worker_devices is None + else lc.gpu_worker_devices + ) + cpu_worker_devices = ( + ["cpu"] * num_cpu_workers + if lc.cpu_worker_devices is None + else lc.cpu_worker_devices + ) + assert len(cpu_worker_devices) == num_cpu_workers + assert len(gpu_worker_devices) == num_gpu_workers + if num_cpu_workers == 0: # pragma: no cover + ( + num_cpu_workers, + num_gpu_workers, + cpu_worker_devices, + gpu_worker_devices, + gpu_pipe_names, + ) = (num_gpu_workers, 0, gpu_worker_devices, [], []) + + exchanger = Exchanger( + mp, + num_stages=len(gpu_pipe_names), + num_cpu_workers=num_cpu_workers, + num_gpu_workers=num_gpu_workers, + gpu_worker_devices=gpu_worker_devices, + ) + + lc = lc.to("cpu") + + cpu_workers = [] + gpu_workers = [] + + with tempfile.NamedTemporaryFile(delete=False) as fp: + dump(lc.worker_copy(), fp) + fp.close() + + revert_pickler = replace_pickler() + + for gpu_idx in range(num_gpu_workers): + gpu_workers.append( + mp.Process( + target=GPUWorker( + gpu_idx=gpu_idx, + exchanger=exchanger, + gpu_pipe_names=gpu_pipe_names, + lazy_collection_path=fp.name, + device=gpu_worker_devices[gpu_idx], + ).run + ) + ) + + for cpu_idx in range(num_cpu_workers): + cpu_workers.append( + mp.Process( + target=CPUWorker( + 
cpu_idx=cpu_idx, + exchanger=exchanger, + gpu_pipe_names=gpu_pipe_names, + lazy_collection_path=fp.name, + device=cpu_worker_devices[cpu_idx], + ).run + ) + ) + + logging.info(f"Main PID {os.getpid()}") + + logging.info( + f"Starting {num_cpu_workers} cpu workers and {num_gpu_workers} gpu workers on " + f"{gpu_worker_devices}, with accelerated pipes: {gpu_pipe_names}", + ) + + for worker in (*cpu_workers, *gpu_workers): + worker.start() + + for i in range(len((*cpu_workers, *gpu_workers))): + assert exchanger.outputs_queue.get() is None + + os.unlink(fp.name) + + logging.info("Workers are ready") + + def process(): + try: + num_max_enqueued = 1 + # Number of input/output batch per process + outputs_iterator = exchanger.iter_results() + + cpu_worker_indices = list(range(num_cpu_workers)) + inputs_iterator = lc.reader.read_main() + workloads = [{} for _ in cpu_worker_indices] + + bar = nullcontext() + if show_progress: + from tqdm import tqdm + + bar = tqdm(smoothing=0.1, mininterval=5.0) + + with bar: + for input_task_id, items in enumerate( + batchify( + iterable=inputs_iterator, + batch_size=lc.chunk_size, + drop_last=False, + formula=lambda x: sum(item[1] for item in x), + ) + ): + batch = [item[0] for item in items] + batch_size = sum(item[1] for item in items) + + if all( + sum(wl.values()) >= lc.chunk_size * num_max_enqueued + for wl in workloads + ): + outputs, count, cpu_idx, output_task_id = next(outputs_iterator) + if isinstance(outputs, BaseException): + raise outputs + if show_progress: + bar.update(count) + yield outputs + if output_task_id is not None: + workloads[cpu_idx].pop(output_task_id, None) + + # Shuffle to ensure the first process does not receive all the + # documents in case of workload equality + shuffle(cpu_worker_indices) + cpu_idx = min( + cpu_worker_indices, + key=lambda i: sum(workloads[i].values()), + ) + exchanger.put_cpu((input_task_id, batch), stage=0, idx=cpu_idx) + workloads[cpu_idx][input_task_id] = batch_size + + # Inform the CPU workers that there are no more tasks to process + for i, worker in enumerate(cpu_workers): + exchanger.cpu_inputs_queues[i][0].put(None) + + while any(workloads): + outputs, count, cpu_idx, output_task_id = next(outputs_iterator) + if isinstance(outputs, BaseException): + raise outputs # pragma: no cover + if show_progress: + bar.update(count) + yield outputs + workloads[cpu_idx].pop(output_task_id, None) + finally: + revert_pickler() + + # Send gpu and cpu process the order to stop processing data + # We use the prioritized queue to ensure the stop signal is processed + # before the next batch of data + for i, worker in enumerate(gpu_workers): + exchanger.gpu_inputs_queues[i][-1].put(None) + for i, worker in enumerate(cpu_workers): + exchanger.cpu_inputs_queues[i][-1].put(None) + + # Enqueue a final non prioritized STOP signal to ensure there remains no + # data in the queues (cf drain loop in CPUWorker / GPUWorker) + for i, worker in enumerate(gpu_workers): + exchanger.gpu_inputs_queues[i][0].put(None) + for i, worker in enumerate(gpu_workers): + worker.join(timeout=5) + for i, worker in enumerate(cpu_workers): + exchanger.cpu_inputs_queues[i][0].put(None) + for i, worker in enumerate(cpu_workers): + worker.join(timeout=1) + + # If a worker is still alive, kill it + # This should not happen, but for a reason I cannot explain, it does in + # some CPU workers sometimes when we catch an error, even though each run + # method of the workers completes cleanly. 
Maybe this has something to do + # with the cleanup of these processes ? + for i, worker in enumerate(gpu_workers): # pragma: no cover + if worker.is_alive(): + logging.error(f"Killing gpu worker {i}") + worker.kill() + for i, worker in enumerate(cpu_workers): # pragma: no cover + if worker.is_alive(): + logging.error(f"Killing cpu worker {i}") + worker.kill() + + gen = process() + return lc.writer.write_main(gen) if lc.writer is not None else flatten(gen) diff --git a/edspdf/processing/simple.py b/edspdf/processing/simple.py new file mode 100644 index 00000000..6746a201 --- /dev/null +++ b/edspdf/processing/simple.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +import sys +from contextlib import nullcontext +from typing import TYPE_CHECKING + +from edspdf.utils.collections import batchify, flatten + +if TYPE_CHECKING: + from edspdf.lazy_collection import LazyCollection + + +def execute_simple_backend( + lc: LazyCollection, +): + """ + This is the default execution mode which batches the documents and processes each + batch on the current process in a sequential manner. + """ + try: + no_grad = sys.modules["torch"].no_grad + except (KeyError, AttributeError): + no_grad = nullcontext + reader = lc.reader + writer = lc.writer + show_progress = lc.show_progress + + def process(): + bar = nullcontext() + if show_progress: + from tqdm import tqdm + + bar = tqdm() + + with bar: + for batch in batchify( + ( + subtask + for task, count in reader.read_main() + for subtask in reader.read_worker([task]) + ), + batch_size=lc.batch_size, + ): + with no_grad(), lc.cache(): + for name, pipe, kwargs in lc.pipeline: + if hasattr(pipe, "batch_process"): + batch = pipe.batch_process(batch, **kwargs) + else: + batch = [pipe(doc, **kwargs) for doc in batch] + + if writer is not None: + result, count = writer.write_worker(batch) + if show_progress: + bar.update(count) + yield result + else: + if show_progress: + bar.update(len(batch)) + yield batch + if writer is not None: + result, count = writer.finalize() + if show_progress: + bar.update(count) + if count: + yield result + + gen = process() + return flatten(gen) if writer is None else writer.write_main(gen) diff --git a/edspdf/registry.py b/edspdf/registry.py index 641b0eb2..5219b4cd 100644 --- a/edspdf/registry.py +++ b/edspdf/registry.py @@ -220,3 +220,5 @@ class registry(RegistryCollection): misc = Registry(("edspdf", "misc"), entry_points=True) adapter = Registry(("edspdf", "adapter"), entry_points=True) accelerator = Registry(("edspdf", "accelerator"), entry_points=True) + readers = Registry(("edspdf", "readers"), entry_points=True) + writers = Registry(("edspdf", "writers"), entry_points=True) diff --git a/edspdf/trainable_pipe.py b/edspdf/trainable_pipe.py index e87ae84a..0681d392 100644 --- a/edspdf/trainable_pipe.py +++ b/edspdf/trainable_pipe.py @@ -1,28 +1,34 @@ +import os from abc import ABCMeta from enum import Enum from functools import wraps -from pathlib import Path from typing import ( Any, + Callable, Dict, Generic, Iterable, Optional, Sequence, + Set, + Tuple, TypeVar, Union, ) +import safetensors.torch import torch from edspdf.pipeline import Pipeline from edspdf.structures import PDFDoc from edspdf.utils.collections import batch_compress_dict, decompress_dict -NestedSequences = Dict[str, Union["NestedSequences", Sequence]] -NestedTensors = Dict[str, Union["NestedSequences", torch.Tensor]] -InputBatch = TypeVar("InputBatch", bound=NestedTensors) -OutputBatch = TypeVar("OutputBatch", bound=NestedTensors) +BatchInput = 
TypeVar("BatchInput", bound=Dict[str, Any]) +BatchOutput = TypeVar("BatchOutput", bound=Dict[str, Any]) +Scorer = Callable[[Sequence[Tuple[PDFDoc, PDFDoc]]], Union[float, Dict[str, Any]]] + +ALL_CACHES = object() +_caches = {} class CacheEnum(str, Enum): @@ -36,19 +42,24 @@ def hash_batch(batch): return hash(tuple(id(item) for item in batch)) elif not isinstance(batch, dict): return id(batch) - return hash((tuple(batch.keys()), tuple(map(hash_batch, batch.values())))) + if "__batch_hash__" in batch: + return batch["__batch_hash__"] + batch_hash = hash((tuple(batch.keys()), tuple(map(hash_batch, batch.values())))) + batch["__batch_hash__"] = batch_hash + return batch_hash def cached_preprocess(fn): @wraps(fn) def wrapped(self: "TrainablePipe", doc: PDFDoc): - if self.pipeline is None or self.pipeline._cache is None: + if self._current_cache_id is None: return fn(self, doc) - cache_id = (id(self), "preprocess", id(doc)) - if cache_id in self.pipeline._cache: - return self.pipeline._cache[cache_id] + cache_key = ("preprocess", f"{type(self)}<{id(self)}>: {id(doc)}") + cache = _caches[self._current_cache_id] + if cache_key in cache: + return cache[cache_key] res = fn(self, doc) - self.pipeline._cache[cache_id] = res + cache[cache_key] = res return res return wrapped @@ -57,31 +68,30 @@ def wrapped(self: "TrainablePipe", doc: PDFDoc): def cached_preprocess_supervised(fn): @wraps(fn) def wrapped(self: "TrainablePipe", doc: PDFDoc): - if self.pipeline is None or self.pipeline._cache is None: + if self._current_cache_id is None: return fn(self, doc) - cache_id = (id(self), "preprocess_supervised", id(doc)) - if cache_id in self.pipeline._cache: - return self.pipeline._cache[cache_id] + cache_key = ("preprocess_supervised", f"{type(self)}<{id(self)}>: {id(doc)}") + cache = _caches[self._current_cache_id] + if cache_key in cache: + return cache[cache_key] res = fn(self, doc) - self.pipeline._cache[cache_id] = res + cache[cache_key] = res return res return wrapped def cached_collate(fn): - import torch - @wraps(fn) - def wrapped(self: "TrainablePipe", batch: Dict, device: torch.device): - if self.pipeline is None or self.pipeline._cache is None: - return fn(self, batch, device) - cache_id = (id(self), "collate", hash_batch(batch)) - if cache_id in self.pipeline._cache: - return self.pipeline._cache[cache_id] - res = fn(self, batch, device) - self.pipeline._cache[cache_id] = res - res["cache_id"] = cache_id + def wrapped(self: "TrainablePipe", batch: Dict): + if self._current_cache_id is None: + return fn(self, batch) + cache_key = ("collate", f"{type(self)}<{id(self)}>: {hash_batch(batch)}") + cache = _caches[self._current_cache_id] + if cache_key in cache: + return cache[cache_key] + res = fn(self, batch) + cache[cache_key] = res return res return wrapped @@ -90,13 +100,35 @@ def wrapped(self: "TrainablePipe", batch: Dict, device: torch.device): def cached_forward(fn): @wraps(fn) def wrapped(self: "TrainablePipe", batch): - if self.pipeline is None or self.pipeline._cache is None: + # Convert args and kwargs to a dictionary matching fn signature + if self._current_cache_id is None: return fn(self, batch) - cache_id = (id(self), "collate", hash_batch(batch)) - if cache_id in self.pipeline._cache: - return self.pipeline._cache[cache_id] + cache_key = ("forward", f"{type(self)}<{id(self)}>: {hash_batch(batch)}") + cache = _caches[self._current_cache_id] + if cache_key in cache: + return cache[cache_key] res = fn(self, batch) - self.pipeline._cache[cache_id] = res + cache[cache_key] = res + return res + 
+ return wrapped + + +def cached_batch_to_device(fn): + @wraps(fn) + def wrapped(self: "TrainablePipe", batch, device): + # Convert args and kwargs to a dictionary matching fn signature + if self._current_cache_id is None: + return fn(self, batch, device) + cache_key = ( + "batch_to_device", + f"{type(self)}<{id(self)}>: {hash_batch(batch)}", + ) + cache = _caches[self._current_cache_id] + if cache_key in cache: + return cache[cache_key] + res = fn(self, batch, device) + cache[cache_key] = res return res return wrapped @@ -112,6 +144,10 @@ def __new__(mcs, name, bases, class_dict): ) if "collate" in class_dict: class_dict["collate"] = cached_collate(class_dict["collate"]) + if "batch_to_device" in class_dict: + class_dict["batch_to_device"] = cached_batch_to_device( + class_dict["batch_to_device"] + ) if "forward" in class_dict: class_dict["forward"] = cached_forward(class_dict["forward"]) @@ -120,7 +156,7 @@ def __new__(mcs, name, bases, class_dict): class TrainablePipe( torch.nn.Module, - Generic[OutputBatch], + Generic[BatchOutput], metaclass=TrainablePipeMeta, ): """ @@ -133,15 +169,30 @@ class TrainablePipe( for components that share a common subcomponent. """ + call_super_init = True + def __init__(self, pipeline: Optional[Pipeline], name: Optional[str]): super().__init__() - self.pipeline = pipeline self.name = name - self.cfg = {} - self._preprocess_cache = {} - self._preprocess_supervised_cache = {} - self._collate_cache = {} - self._forward_cache = {} + self._current_cache_id = None + + def enable_cache(self, cache_id="default"): + self._current_cache_id = cache_id + _caches.setdefault(cache_id, {}) + for name, component in self.named_component_children(): + if hasattr(component, "enable_cache"): + component.enable_cache(cache_id) + + def disable_cache(self, cache_id=ALL_CACHES): + if cache_id is ALL_CACHES: + _caches.clear() + else: + if cache_id in _caches: + del _caches[cache_id] + self._current_cache_id = None + for name, component in self.named_component_children(): + if hasattr(component, "disable_cache"): + component.disable_cache(cache_id) @property def device(self): @@ -152,45 +203,12 @@ def named_component_children(self): if isinstance(module, TrainablePipe): yield name, module - def save_extra_data(self, path: Path, exclude: set): - """ - Dumps vocabularies indices to json files - - Parameters - ---------- - path: Path - Path to the directory where the files will be saved - exclude: Set - The set of component names to exclude from saving - This is useful when components are repeated in the pipeline. - """ - if self.name in exclude: - return - exclude.add(self.name) - for name, component in self.named_component_children(): - if hasattr(component, "save_extra_data"): - component.save_extra_data(path / name, exclude) - - def load_extra_data(self, path: Path, exclude: set): - """ - Loads vocabularies indices from json files - - Parameters - ---------- - path: Path - Path to the directory where the files will be loaded - exclude: Set - The set of component names to exclude from loading - This is useful when components are repeated in the pipeline. 
- """ - if self.name in exclude: - return - exclude.add(self.name) - for name, component in self.named_component_children(): - if hasattr(component, "load_extra_data"): - component.load_extra_data(path / name, exclude) + def named_component_modules(self): + for name, module in self.named_modules(): + if isinstance(module, TrainablePipe): + yield name, module - def post_init(self, gold_data: Iterable[PDFDoc], exclude: set): + def post_init(self, gold_data: Iterable[PDFDoc], exclude: Set[str]): """ This method completes the attributes of the component, by looking at some documents. It is especially useful to build vocabularies or detect the labels @@ -205,9 +223,10 @@ def post_init(self, gold_data: Iterable[PDFDoc], exclude: set): This argument will be gradually updated with the names of initialized components """ - if self.name in exclude: + repr_id = object.__repr__(self) + if repr_id in exclude: return - exclude.add(self.name) + exclude.add(repr_id) for name, component in self.named_component_children(): if hasattr(component, "post_init"): component.post_init(gold_data, exclude=exclude) @@ -233,52 +252,79 @@ def preprocess(self, doc: PDFDoc) -> Dict[str, Any]: for name, component in self.named_component_children() } - def collate(self, batch: NestedSequences, device: torch.device) -> InputBatch: + def collate(self, batch: Dict[str, Any]) -> BatchInput: """ Collate the batch of features into a single batch of tensors that can be used by the forward method of the component. Parameters ---------- - batch: NestedSequences + batch: Dict[str, Any] Batch of features - device: torch.device - Device on which the tensors should be moved Returns ------- - InputBatch + BatchInput Dictionary (optionally nested) containing the collated tensors """ return { - name: component.collate(batch[name], device) + name: component.collate(batch[name]) for name, component in self.named_component_children() } - def forward(self, batch: InputBatch) -> OutputBatch: + def batch_to_device( + self, + batch: BatchInput, + device: Optional[Union[str, torch.device]], + ) -> BatchInput: + """ + Move the batch of tensors to the specified device. + + Parameters + ---------- + batch: BatchInput + Batch of tensors + device: Optional[Union[str, torch.device]] + Device to move the tensors to + + Returns + ------- + BatchInput + """ + return { + name: ( + value.to(device) + if hasattr(value, "to") + else getattr(self, name).batch_to_device(value, device=device) + if hasattr(self, name) + else value + ) + for name, value in batch.items() + } + + def forward(self, batch: BatchInput) -> BatchOutput: """ - Perform the forward pass of the neural network, i.e, apply transformations - over the collated features to compute new embeddings, probabilities, losses, etc + Perform the forward pass of the neural network. Parameters ---------- - batch: InputBatch + batch: BatchInput Batch of tensors (nested dictionary) computed by the collate method Returns ------- - OutputBatch + BatchOutput """ raise NotImplementedError() - def module_forward(self, batch: InputBatch) -> OutputBatch: + def module_forward(self, *args, **kwargs): """ This is a wrapper around `torch.nn.Module.__call__` to avoid conflict with the [`TrainablePipe.__call__`][edspdf.trainable_pipe.TrainablePipe.__call__] method. 
""" - return torch.nn.Module.__call__(self, batch) + return torch.nn.Module.__call__(self, *args, **kwargs) def make_batch( self, @@ -323,9 +369,11 @@ def batch_process(self, docs: Sequence[PDFDoc]) -> Sequence[PDFDoc]: Sequence[PDFDoc] Batch of updated documents """ + device = next((p.device for p in self.parameters()), "cpu") with torch.no_grad(): batch = self.make_batch(docs) - inputs = self.collate(batch, device=self.device) + inputs = self.collate(batch) + inputs = self.batch_to_device(inputs, device=device) if hasattr(self, "compiled"): res = self.compiled(inputs) else: @@ -334,7 +382,7 @@ def batch_process(self, docs: Sequence[PDFDoc]) -> Sequence[PDFDoc]: return docs def postprocess( - self, docs: Sequence[PDFDoc], batch: OutputBatch + self, docs: Sequence[PDFDoc], batch: BatchOutput ) -> Sequence[PDFDoc]: """ Update the documents with the predictions of the neural network, for instance @@ -346,7 +394,7 @@ def postprocess( ---------- docs: Sequence[PDFDoc] Batch of documents - batch: OutputBatch + batch: BatchOutput Batch of predictions, as returned by the forward method Returns @@ -391,3 +439,41 @@ def __call__(self, doc: PDFDoc) -> PDFDoc: PDFDoc """ return self.batch_process([doc])[0] + + def to_disk(self, path, exclude: Optional[Set[str]]): + if object.__repr__(self) in exclude: + return + exclude.add(object.__repr__(self)) + overrides = {} + for name, component in self.named_component_children(): + if hasattr(component, "to_disk"): + pipe_overrides = component.to_disk(path / name, exclude=exclude) + if pipe_overrides: + overrides[name] = pipe_overrides + tensor_dict = { + n: p + for n, p in self.named_parameters() + if object.__repr__(p) not in exclude + } + os.makedirs(path, exist_ok=True) + safetensors.torch.save_file(tensor_dict, path / "parameters.safetensors") + exclude.update(object.__repr__(p) for p in tensor_dict.values()) + return overrides + + def from_disk(self, path, exclude: Optional[Set[str]]): + if object.__repr__(self) in exclude: + return + exclude.add(object.__repr__(self)) + for name, component in self.named_component_children(): + if hasattr(component, "from_disk"): + component.from_disk(path / name, exclude=exclude) + tensor_dict = safetensors.torch.load_file(path / "parameters.safetensors") + self.load_state_dict(tensor_dict, strict=False) + + @property + def load_extra_data(self): + return self.from_disk + + @property + def save_extra_data(self): + return self.to_disk diff --git a/edspdf/utils/collections.py b/edspdf/utils/collections.py index f0c37db9..7ffe7861 100644 --- a/edspdf/utils/collections.py +++ b/edspdf/utils/collections.py @@ -1,8 +1,18 @@ import copy import itertools -import math from collections import defaultdict -from typing import Any, Dict, Iterable, List, Mapping, Sequence, TypeVar +from typing import ( + Any, + Callable, + Dict, + Iterable, + Iterator, + List, + Mapping, + Sequence, + TypeVar, + Union, +) It = TypeVar("It", bound=Iterable) T = TypeVar("T") @@ -37,16 +47,46 @@ def nest_dict(flat: Dict[str, Any]) -> Dict[str, Any]: def ld_to_dl(ld: Iterable[Mapping[str, T]]) -> Dict[str, List[T]]: + """ + Convert a list of dictionaries to a dictionary of lists + + Parameters + ---------- + ld: Iterable[Mapping[str, T]] + The list of dictionaries + + Returns + ------- + Dict[str, List[T]] + The dictionary of lists + """ ld = list(ld) - return {k: [dic[k] for dic in ld] for k in ld[0]} + return {k: [dic.get(k) for dic in ld] for k in (ld[0] if len(ld) else ())} + + +def dl_to_ld(dl: Mapping[str, Sequence[Any]]) -> Iterator[Dict[str, 
Any]]: + """ + Convert a dictionary of lists to a list of dictionaries + Parameters + ---------- + dl: Mapping[str, Sequence[Any]] + The dictionary of lists -def dl_to_ld(dl: Mapping[str, Sequence[T]]) -> List[Dict[str, T]]: - return [dict(zip(dl, t)) for t in zip(*dl.values())] + Returns + ------- + List[Dict[str, Any]] + The list of dictionaries + """ + return (dict(zip(dl, t)) for t in zip(*dl.values())) -def flatten(seq: Sequence[Sequence["T"]]) -> List["T"]: - return list(itertools.chain.from_iterable(seq)) +def flatten(items): + for item in items: + if isinstance(item, list): + yield from flatten(item) + else: + yield item FLATTEN_TEMPLATE = """\ @@ -64,16 +104,17 @@ def rec(current, path): keys[id(current)].append(path) return for key, value in current.items(): - rec(value, (*path, key)) + if not key.startswith("$"): + rec(value, (*path, key)) rec(obj, ()) code = FLATTEN_TEMPLATE.format( "{" + "\n".join( - "'{}': root{},".format( - "|".join(map("/".join, key_list)), - "".join(f"['{k}']" for k in key_list[0]), + "{}: root{},".format( + repr("|".join(map("/".join, key_list))), + "".join(f"[{repr(k)}]" for k in key_list[0]), ) for key_list in keys.values() ) @@ -83,6 +124,19 @@ def rec(current, path): class batch_compress_dict: + """ + Compress a sequence of dictionaries in which values that occur multiple times are + deduplicated. The corresponding keys will be merged into a single string using + the "|" character as a separator. + This is useful to preserve referential identities when decompressing the dictionary + after it has been serialized and deserialized. + + Parameters + ---------- + seq: Iterable[Dict[str, Any]] + Sequence of dictionaries to compress + """ + __slots__ = ("flatten", "seq") def __init__(self, seq: Iterable[Dict[str, Any]]): @@ -109,7 +163,23 @@ def __next__(self) -> Dict[str, List]: return self.flatten(item) -def decompress_dict(seq): +def decompress_dict(seq: Union[Iterable[Dict[str, Any]], Dict[str, Any]]): + """ + Decompress a dictionary of lists into a sequence of dictionaries. + This function assumes that the dictionary structure was obtained using the + `batch_compress_dict` class. + Keys that were merged into a single string using the "|" character as a separator + will be split into a nested dictionary structure. + + Parameters + ---------- + seq: Union[Iterable[Dict[str, Any]], Dict[str, Any]] + The dictionary to decompress or a sequence of dictionaries to decompress + + Returns + ------- + + """ obj = ld_to_dl(seq) if isinstance(seq, Sequence) else seq res = {} for key, value in obj.items(): @@ -122,26 +192,33 @@ def decompress_dict(seq): return res -class batchify(Iterable[List[T]]): - def __init__(self, iterable: Iterable[T], batch_size: int): - self.iterable = iter(iterable) - self.batch_size = batch_size - try: - self.length = math.ceil(len(iterable) / batch_size) - except (AttributeError, TypeError): - pass - - def __len__(self): - return self.length - - def __iter__(self): - return self - - def __next__(self): - batch = list(itertools.islice(self.iterable, self.batch_size)) - if len(batch) == 0: - raise StopIteration() - return batch +def batchify( + iterable: Iterable[T], + batch_size: int, + drop_last: bool = False, + formula: Callable = len, +) -> Iterable[List[T]]: + """ + Yields batch that contain at most `batch_size` elements. + If an item contains more than `batch_size` elements, it will be yielded as a single + batch. 
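    A small worked example of the size-aware batching this implements, using a
    formula in the spirit of the per-line batching done by the CPU workers (numbers
    are illustrative):

        docs = [{"id": i, "n_lines": n} for i, n in enumerate([4, 3, 5, 1, 2])]
        batches = list(
            batchify(docs, batch_size=7, formula=lambda b: sum(d["n_lines"] for d in b))
        )
        # -> three batches: ids [0, 1], [2, 3] and [4] (7, 6 and 2 lines respectively)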
+ + Parameters + ---------- + iterable: Iterable[T] + batch_size: int + drop_last: bool + formula: Callable + """ + batch = [] + for item in iterable: + next_size = formula(batch + [item]) + if next_size > batch_size and len(batch) > 0: + yield batch + batch = [] + batch.append(item) + if len(batch) > 0 and not drop_last: + yield batch def get_attr_item(base, attr): diff --git a/edspdf/utils/lazy_module.py b/edspdf/utils/lazy_module.py new file mode 100644 index 00000000..ddc75d61 --- /dev/null +++ b/edspdf/utils/lazy_module.py @@ -0,0 +1,108 @@ +# flake8: noqa: F811 +import ast +import importlib +import inspect +import os + + +def lazify(): + def _get_module_paths(file): + """ + Reads the content of the current file, parses it with ast and store the + import path for future potential imports. This is useful to only import + the module that is requested and avoid loading all the modules at once, since + some of them are quite heavy, or contain dependencies that are not always + available. + + For instance: + > from .trainable.span_qualifier.factory import create_component as + span_qualifier is stored in the cache as: + > module_paths["span_qualifier"] = "trainable.span_qualifier.factory" + + Returns + ------- + Dict[str, Tuple[str, str]] + The absolute path of the current file. + """ + module_path = os.path.abspath(file) + with open(module_path, "r") as f: + module_content = f.read() + module_ast = ast.parse(module_content) + module_paths = {} + for node in module_ast.body: + # Lookup TYPE_CHECKING + if not ( + isinstance(node, ast.If) + and ( + ( + isinstance(node.test, ast.Name) + and node.test.id == "TYPE_CHECKING" + ) + or ( + isinstance(node.test, ast.Attribute) + and node.test.attr == "TYPE_CHECKING" + ) + ) + ): + continue + for import_node in node.body: + if isinstance(import_node, ast.ImportFrom): + for name in import_node.names: + module_paths[name.asname or name.name] = ( + import_node.module, + name.name, + ) + + return module_paths + + def __getattr__(name): + """ + Imports the actual module if it is in the module_paths dict. + + Parameters + ---------- + name + + Returns + ------- + + """ + if name in module_paths: + module_path, module_name = module_paths[name] + result = getattr( + importlib.__import__( + module_path, + fromlist=[module_name], + globals=module_globals, + level=1, + ), + module_name, + ) + module_globals[name] = result + return result + raise AttributeError(f"module {__name__} has no attribute {name}") + + def __dir__(): + """ + Returns the list of available modules. 
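    The resulting pattern in a package `__init__.py` is the one used by
    `edspdf/processing/__init__.py` above (annotated here for clarity):

        from typing import TYPE_CHECKING

        from edspdf.utils.lazy_module import lazify

        lazify()  # installs a module-level __getattr__/__dir__ from the block below

        if TYPE_CHECKING:
            # parsed (not executed) at import time; each name is imported lazily
            # on first attribute access
            from .simple import execute_simple_backend
            from .multiprocessing import execute_multiprocessing_backend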
+ + Returns + ------- + List[str] + """ + return __all__ + + # Access upper frame + module_globals = inspect.currentframe().f_back.f_globals + + module_paths = _get_module_paths(module_globals["__file__"]) + + __all__ = list(module_paths.keys()) + + module_globals.update( + { + "__getattr__": __getattr__, + "__dir__": __dir__, + "__all__": __all__, + } + ) diff --git a/edspdf/utils/package.py b/edspdf/utils/package.py index 8440fa65..23b97730 100644 --- a/edspdf/utils/package.py +++ b/edspdf/utils/package.py @@ -6,7 +6,7 @@ import sys from contextlib import contextmanager from pathlib import Path -from types import FunctionType +from types import FunctionType, ModuleType from typing import ( TYPE_CHECKING, Any, @@ -25,9 +25,13 @@ from build.__main__ import build_package, build_package_via_sdist from confit import Cli from dill._dill import save_function as dill_save_function +from dill._dill import save_module as dill_save_module from dill._dill import save_type as dill_save_type -from importlib_metadata import PackageNotFoundError -from importlib_metadata import version as get_version + +try: + import importlib_metadata +except ImportError: # pragma: no cover + import importlib.metadata as importlib_metadata from loguru import logger from typing_extensions import Literal @@ -36,22 +40,25 @@ py_version = f"{sys.version_info.major}.{sys.version_info.minor}" -def get_package(obj_type: Type): +def get_package(obj: Type): # Retrieve the __package__ attribute of the module of a type, if possible. # And returns the package version as well try: - module_name = obj_type.__module__ + if isinstance(obj, ModuleType): + module_name = obj.__name__ + else: + module_name = obj.__module__ if module_name == "__main__": - raise Exception(f"Could not find package of type {obj_type}") + raise Exception(f"Could not find package of {obj}") module = __import__(module_name, fromlist=["__package__"]) - package = module.__package__ + package = module.__package__.split(".")[0] try: - version = get_version(package) - except (PackageNotFoundError, ValueError): + version = importlib_metadata.version(package) + except (importlib_metadata.PackageNotFoundError, ValueError): return None return package, version except (ImportError, AttributeError): - raise Exception(f"Cound not find package of type {obj_type}") + raise Exception(f"Cound not find package of type {obj}") def save_type(pickler, obj, *args, **kwargs): @@ -68,11 +75,19 @@ def save_function(pickler, obj, *args, **kwargs): return dill_save_function(pickler, obj, *args, **kwargs) +def save_module(pickler, obj, *args, **kwargs): + package_name = get_package(obj) + if package_name is not None: + pickler.packages.add(package_name) + return dill_save_module(pickler, obj, *args, **kwargs) + + class PackagingPickler(dill.Pickler): dispatch = dill.Pickler.dispatch.copy() dispatch[FunctionType] = save_function dispatch[type] = save_type + dispatch[ModuleType] = save_module def __init__(self, *args, **kwargs): self.file = io.BytesIO() @@ -81,7 +96,7 @@ def __init__(self, *args, **kwargs): def get_deep_dependencies(obj): - pickler = PackagingPickler() + pickler = PackagingPickler(byref=True) pickler.dump(obj) return sorted(pickler.packages) @@ -132,6 +147,8 @@ def validate(cls, value, config=None): # Initialize the builder try: builder = SdistBuilder(poetry, None, None) + # Get the list of files to include + files = builder.find_files_to_add() except ModuleOrPackageNotFound: if not poetry.package.packages: print([]) @@ -139,15 +156,13 @@ def validate(cls, value, 
config=None): print([ {k: v for k, v in { - "include": include._include, - "from": include.source, - "formats": include.formats, - }.items()} + "include": getattr(include, '_include'), + "from": getattr(include, 'source', None), + "formats": getattr(include, 'formats', None), + }.items() if v} for include in builder._module.includes ]) -# Get the list of files to include -files = builder.find_files_to_add() # Print the list of files for file in files: @@ -161,12 +176,16 @@ def validate(cls, value, config=None): import edspdf from pathlib import Path +from typing import Optional, Dict, Any __version__ = {__version__} -def load(device: "torch.device" = "cpu") -> edspdf.Pipeline: +def load( + overrides: Optional[Dict[str, Any]] = None, + device: "torch.device" = "cpu" +) -> edspdf.Pipeline: artifacts_path = Path(__file__).parent / "{artifacts_dir}" - model = edspdf.load(artifacts_path, device=device) + model = edspdf.load(artifacts_path, overrides=overrides, device=device) return model """ @@ -194,11 +213,11 @@ def __init__( self, pyproject: Optional[Dict[str, Any]], pipeline: Union[Path, "edspdf.Pipeline"], - version: str, + version: Optional[str], name: Optional[ModuleName], root_dir: Path = ".", - build_name: Path = "build", - out_dir: Path = "dist", + build_dir: Path = "build", + dist_dir: Path = "dist", artifacts_name: ModuleName = "artifacts", dependencies: Optional[Sequence[Tuple[str, str]]] = None, metadata: Optional[Dict[str, Any]] = {}, @@ -215,10 +234,11 @@ def __init__( self.dependencies = dependencies self.pipeline = pipeline self.artifacts_name = artifacts_name - self.out_dir = self.root_dir / out_dir + self.dist_dir = ( + dist_dir if Path(dist_dir).is_absolute() else self.root_dir / dist_dir + ) with self.ensure_pyproject(metadata): - python_executable = ( Path(self.poetry_bin_path).read_text().split("\n")[0][2:] ) @@ -236,7 +256,9 @@ def __init__( out = result.stdout.decode().strip().split("\n") self.poetry_packages = eval(out[0]) - self.build_dir = root_dir / build_name / self.name + self.build_dir = ( + build_dir if Path(build_dir).is_absolute() else root_dir / build_dir + ) / self.name self.file_paths = [self.root_dir / file_path for file_path in out[1:]] logger.info(f"root_dir: {self.root_dir}") @@ -262,7 +284,7 @@ def ensure_pyproject(self, metadata): "poetry": { **metadata, "name": self.name, - "version": self.version, + "version": self.version or "0.1.0", "dependencies": { "python": f">={py_version},<4.0", **{ @@ -319,7 +341,7 @@ def build( distributions = ["wheel"] build_call( srcdir=self.build_dir, - outdir=self.out_dir, + outdir=self.dist_dir, distributions=distributions, config_settings=config_settings, isolation=isolation, @@ -335,12 +357,13 @@ def update_pyproject(self): f"project" ) - old_version = self.pyproject["tool"]["poetry"]["version"] - self.pyproject["tool"]["poetry"]["version"] = self.version - logger.info( - f"Replaced project version {old_version!r} with {self.version!r} in poetry " - f"based project" - ) + if self.version is not None: + old_version = self.pyproject["tool"]["poetry"]["version"] + self.pyproject["tool"]["poetry"]["version"] = self.version + logger.info( + f"Replaced project version {old_version!r} with {self.version!r} in " + f"poetry based project" + ) # Adding artifacts to include in pyproject.toml snake_name = snake_case(self.name.lower()) @@ -381,7 +404,7 @@ def make_src_dir(self): build_artifacts_dir, ) else: - self.pipeline.save(build_artifacts_dir) + self.pipeline.to_disk(build_artifacts_dir) os.makedirs(package_dir, 
exist_ok=True) with open(package_dir / "__init__.py", mode="a") as f: f.write( @@ -397,10 +420,12 @@ def package( pipeline: Union[Path, "edspdf.Pipeline"], name: Optional[ModuleName] = None, root_dir: Path = ".", + build_dir: Path = "build", + dist_dir: Path = "dist", artifacts_name: ModuleName = "artifacts", check_dependencies: bool = False, project_type: Optional[Literal["poetry", "setuptools"]] = None, - version: str = "0.1.0", + version: Optional[str] = None, metadata: Optional[Dict[str, Any]] = {}, distributions: Optional[Sequence[Literal["wheel", "sdist"]]] = ["wheel"], config_settings: Optional[Mapping[str, Union[str, Sequence[str]]]] = None, @@ -442,6 +467,8 @@ def package( name=name, version=version, root_dir=root_dir, + build_dir=build_dir, + dist_dir=dist_dir, artifacts_name=artifacts_name, dependencies=dependencies, metadata=metadata, diff --git a/pyproject.toml b/pyproject.toml index ed6024d6..8a0c83a4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -142,16 +142,14 @@ concurrency = ["multiprocessing"] omit = [ "tests/*", ] -# omit = [ -# "edspdf/accelerators/multiprocessing.py", -# ] -exclude_also = [ +exclude_lines = [ "def __repr__", "if __name__ == .__main__.:", "@overload", "pragma: no cover", "raise .*Error", "raise .*Exception", + "warn\\(", "if __name__ == .__main__.:", "if (self[.])?name in exclude:", "if TYPE_CHECKING:", diff --git a/tests/core/test_data.py b/tests/core/test_data.py new file mode 100644 index 00000000..21ef872a --- /dev/null +++ b/tests/core/test_data.py @@ -0,0 +1,66 @@ +import json + +import pandas as pd +import pytest + +import edspdf +import edspdf.accelerators.multiprocessing +from edspdf.data.converters import CONTENT, FILENAME +from edspdf.utils.collections import flatten + + +def box_converter(x): + return [ + { + "id": x.id, + "page_num": b.page_num, + "x0": b.x0, + "x1": b.x1, + "y0": b.y0, + "y1": b.y1, + } + for b in x.content_boxes + ] + + +def full_file_converter(x): + return { + FILENAME: x.id, + CONTENT: x.content, + "annotations": [ + { + "page_num": b.page_num, + "x0": b.x0, + "x1": b.x1, + "y0": b.y0, + "y1": b.y1, + } + for b in x.content_boxes + ], + } + + +@pytest.mark.parametrize("write_mode", ["parquet", "pandas", "iterable", "files"]) +def test_from_files(frozen_pipeline, write_mode, tmp_path, change_test_dir): + docs = edspdf.data.read_files("../resources") + docs = docs.map_pipeline(frozen_pipeline) + if write_mode == "parquet": + docs.write_parquet( + tmp_path / "parquet" / "test.parquet", + converter=box_converter, + ) + df = pd.read_parquet(tmp_path / "parquet" / "test.parquet") + elif write_mode == "pandas": + df = docs.to_pandas(converter=box_converter) + elif write_mode == "iterable": + df = pd.DataFrame(flatten(docs.to_iterable(converter=box_converter))) + else: + docs.write_files( + tmp_path / "files", + converter=full_file_converter, + ) + records = [] + for f in (tmp_path / "files").rglob("*.json"): + records.extend(json.loads(f.read_text())["annotations"]) + df = pd.DataFrame(records) + assert len(df) == 91 diff --git a/tests/core/test_pipeline.py b/tests/core/test_pipeline.py index 554ef205..42a445ff 100644 --- a/tests/core/test_pipeline.py +++ b/tests/core/test_pipeline.py @@ -223,11 +223,13 @@ def score(golds, preds): def test_cache(pipeline: Pipeline, dummy_dataset: Path, pdf: bytes): + from edspdf.trainable_pipe import _caches + pipeline(pdf) with pipeline.cache(): pipeline(pdf) - assert len(pipeline._cache) > 0 + assert len(_caches["default"]) > 0 assert pipeline._cache is None @@ -249,11 +251,11 @@ def 
test_different_names(pipeline: Pipeline): extractor = PdfMinerExtractor(pipeline=pipeline, name="custom_name") - with pytest.raises(ValueError) as exc_info: + with pytest.warns() as record: pipeline.add_pipe(extractor, name="extractor") assert "The provided name does not match the name of the component." in str( - exc_info.value + record[0].message ) @@ -340,7 +342,7 @@ def error_pipe(doc: PDFDoc): return doc -def test_multiprocessing_gpu_stub(frozen_pipeline, pdf, letter_pdf): +def test_deprecated_multiprocessing_gpu_stub(frozen_pipeline, pdf, letter_pdf): edspdf.accelerators.multiprocessing.MAX_NUM_PROCESSES = 2 accelerator = edspdf.accelerators.multiprocessing.MultiprocessingAccelerator( batch_size=2, @@ -364,6 +366,30 @@ def test_multiprocessing_gpu_stub(frozen_pipeline, pdf, letter_pdf): ) +def test_multiprocessing_gpu_stub(frozen_pipeline, pdf, letter_pdf): + edspdf.accelerators.multiprocessing.MAX_NUM_PROCESSES = 2 + iterator = chain.from_iterable( + [ + {"content": pdf}, + {"content": letter_pdf}, + ] + for i in range(5) + ) + docs = edspdf.data.from_iterable( + iterator, converter=lambda x: PDFDoc(content=x["content"]) + ) + docs = docs.map_pipeline(frozen_pipeline) + docs = docs.set_processing( + batch_size=2, + num_gpu_workers=1, + num_cpu_workers=1, + gpu_worker_devices=["cpu"], + sort_by_size=True, + batch_unit="lines", + ) + docs = docs.to_iterable(converter=lambda x: {"text": x.aggregated_texts}) + + def test_multiprocessing_rb_error(pipeline, pdf, letter_pdf): edspdf.accelerators.multiprocessing.MAX_NUM_PROCESSES = 2 pipeline.add_pipe(error_pipe, name="error", after="extractor") @@ -391,9 +417,9 @@ def __init__(self, *args, **kwargs): def preprocess(self, doc): return {"num_boxes": len(doc.content_boxes), "doc_id": doc.id} - def collate(self, batch, device): + def collate(self, batch): return { - "num_boxes": torch.tensor(batch["num_boxes"], device=device), + "num_boxes": torch.tensor(batch["num_boxes"]), "doc_id": batch["doc_id"], } diff --git a/tests/utils/test_package.py b/tests/utils/test_package.py index 28168b12..8ebadbc8 100644 --- a/tests/utils/test_package.py +++ b/tests/utils/test_package.py @@ -133,12 +133,16 @@ def test_package_with_files(frozen_pipeline, tmp_path, package_name): import edspdf from pathlib import Path +from typing import Optional, Dict, Any __version__ = '0.1.0' -def load(device: "torch.device" = "cpu") -> edspdf.Pipeline: +def load( + overrides: Optional[Dict[str, Any]] = None, + device: "torch.device" = "cpu" +) -> edspdf.Pipeline: artifacts_path = Path(__file__).parent / "artifacts" - model = edspdf.load(artifacts_path, device=device) + model = edspdf.load(artifacts_path, overrides=overrides, device=device) return model """ )
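For completeness, a sketch of how the updated packaging entry point is invoked with the new `build_dir` and `dist_dir` arguments (the model path and the distribution name are hypothetical):

    import edspdf
    from edspdf.utils.package import package

    model = edspdf.load("path/to/model")  # hypothetical trained pipeline
    package(
        pipeline=model,
        name="my_packaged_model",  # hypothetical distribution name
        version="0.1.0",
        build_dir="build",  # intermediate build tree (new argument)
        dist_dir="dist",    # wheels are written here (new argument)
    )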