From 06f142a40f379d689727335a6a29c7753655b3cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Perceval=20Wajsb=C3=BCrt?= Date: Fri, 14 Jun 2024 16:47:51 +0200 Subject: [PATCH] refacto: update trainable components and data api --- changelog.md | 7 + docs/trainable-pipes.md | 9 +- edspdf/data/parquet.py | 116 ++++----- edspdf/lazy_collection.py | 31 +++ edspdf/pipes/classifiers/trainable.py | 9 +- edspdf/processing/multiprocessing.py | 343 ++++++++++++++++++-------- edspdf/processing/simple.py | 39 +-- edspdf/processing/utils.py | 85 +++++++ edspdf/trainable_pipe.py | 41 +-- tests/core/test_data.py | 4 +- 10 files changed, 471 insertions(+), 213 deletions(-) create mode 100644 edspdf/processing/utils.py diff --git a/changelog.md b/changelog.md index 7d7873f..1f7f419 100644 --- a/changelog.md +++ b/changelog.md @@ -2,10 +2,17 @@ ## Unreleased +### Changed + +- Default to fp16 when inferring with gpu +- Support `inputs` parameter in `TrainablePipe.postprocess(...)` method (as in edsnlp) +- We now check that the user isn't trying to write a single file in a split fashion (when `write_in_worker is True ` or `num_rows_per_file is not None`) and raise an error if they do + ### Fixed - Batches full of empty content boxes no longer crash the `huggingface-embedding` component - Ensure models are always loaded in non training mode +- Improved performance of `edsnlp.data` methods over a filesystem (`fs` parameter) ## v0.9.1 diff --git a/docs/trainable-pipes.md b/docs/trainable-pipes.md index ae6f37b..78a7dac 100644 --- a/docs/trainable-pipes.md +++ b/docs/trainable-pipes.md @@ -53,7 +53,7 @@ Additionally, there is a fifth method: Here is an example of a trainable component: ```python -from typing import Any, Dict, Iterable, Sequence +from typing import Any, Dict, Iterable, Sequence, List import torch from tqdm import tqdm @@ -114,7 +114,12 @@ class MyComponent(TrainablePipe): return output - def postprocess(self, docs: Sequence[PDFDoc], output: Dict) -> Sequence[PDFDoc]: + def postprocess( + self, + docs: Sequence[PDFDoc], + output: Dict, + inputs: List[Dict[str, Any]], + ) -> Sequence[PDFDoc]: # Annotate the docs with the outputs of the forward method ... return docs diff --git a/edspdf/data/parquet.py b/edspdf/data/parquet.py index 813ff98..74d742d 100644 --- a/edspdf/data/parquet.py +++ b/edspdf/data/parquet.py @@ -1,8 +1,9 @@ -import os +import sys from itertools import chain from pathlib import Path from typing import Any, Callable, Dict, Iterable, List, Optional, TypeVar, Union +import fsspec import pyarrow.dataset import pyarrow.fs import pyarrow.parquet @@ -17,6 +18,7 @@ from edspdf.lazy_collection import LazyCollection from edspdf.structures import PDFDoc, registry from edspdf.utils.collections import dl_to_ld, flatten, ld_to_dl +from edspdf.utils.filesystem import FileSystem, normalize_fs_path class ParquetReader(BaseReader): @@ -27,27 +29,15 @@ def __init__( path: Union[str, Path], *, read_in_worker: bool, - filesystem: Optional[pyarrow.fs.FileSystem] = None, + filesystem: Optional[FileSystem] = None, ): super().__init__() # Either the filesystem has not been passed # or the path is a URL (e.g. 
s3://) => we need to infer the filesystem - fs_path = path - if filesystem is None or (isinstance(path, str) and "://" in path): - path = ( - path - if isinstance(path, Path) or "://" in path - else f"file://{os.path.abspath(path)}" - ) - inferred_fs, fs_path = pyarrow.fs.FileSystem.from_uri(path) - filesystem = filesystem or inferred_fs - assert inferred_fs.type_name == filesystem.type_name, ( - f"Protocol {inferred_fs.type_name} in path does not match " - f"filesystem {filesystem.type_name}" - ) + filesystem, path = normalize_fs_path(filesystem, path) self.read_in_worker = read_in_worker self.dataset = pyarrow.dataset.dataset( - fs_path, format="parquet", filesystem=filesystem + path, format="parquet", filesystem=filesystem ) def read_main(self): @@ -60,17 +50,14 @@ def read_main(self): return ( (line, 1) for f in fragments - for batch in f.to_table().to_batches(1024) - for line in dl_to_ld(batch.to_pydict()) + for line in dl_to_ld(f.to_table().to_pydict()) ) def read_worker(self, tasks): if self.read_in_worker: tasks = list( chain.from_iterable( - dl_to_ld(batch.to_pydict()) - for task in tasks - for batch in task.to_table().to_batches(1024) + dl_to_ld(task.to_table().to_pydict()) for task in tasks ) ) return tasks @@ -82,47 +69,55 @@ def read_worker(self, tasks): class ParquetWriter(BaseWriter): def __init__( self, + *, path: Union[str, Path], - num_rows_per_file: int, + num_rows_per_file: Optional[int] = None, overwrite: bool, write_in_worker: bool, accumulate: bool = True, - filesystem: Optional[pyarrow.fs.FileSystem] = None, + filesystem: Optional[FileSystem] = None, ): super().__init__() - fs_path = path - if filesystem is None or (isinstance(path, str) and "://" in path): - path = ( - path - if isinstance(path, Path) or "://" in path - else f"file://{os.path.abspath(path)}" - ) - inferred_fs, fs_path = pyarrow.fs.FileSystem.from_uri(path) - filesystem = filesystem or inferred_fs - assert inferred_fs.type_name == filesystem.type_name, ( - f"Protocol {inferred_fs.type_name} in path does not match " - f"filesystem {filesystem.type_name}" - ) - path = fs_path + filesystem, path = normalize_fs_path(filesystem, path) # Check that filesystem has the same protocol as indicated by path - filesystem.create_dir(fs_path, recursive=True) + looks_like_dir = Path(path).suffix == "" + if looks_like_dir or num_rows_per_file is not None: + num_rows_per_file = num_rows_per_file or 8192 + filesystem.makedirs(path, exist_ok=True) + save_as_dataset = True + else: + assert ( + num_rows_per_file is None + ), "num_rows_per_file should not be set when writing to a single file" + assert ( + write_in_worker is False + ), "write_in_worker cannot be set when writing to a single file" + save_as_dataset = False + num_rows_per_file = sys.maxsize if overwrite is False: - dataset = pyarrow.dataset.dataset( - fs_path, format="parquet", filesystem=filesystem - ) - if len(list(dataset.get_fragments())): - raise FileExistsError( - f"Directory {fs_path} already exists and is not empty. " - "Use overwrite=True to overwrite." + if save_as_dataset: + dataset = pyarrow.dataset.dataset( + path, format="parquet", filesystem=filesystem ) - self.filesystem = filesystem + if len(list(dataset.get_fragments())): + raise FileExistsError( + f"Directory {path} already exists and is not empty. " + "Use overwrite=True to overwrite." + ) + else: + if filesystem.exists(path): + raise FileExistsError( + f"File {path} already exists. Use overwrite=True to overwrite." 
+ ) + self.filesystem: fsspec.AbstractFileSystem = filesystem self.path = path + self.save_as_dataset = save_as_dataset self.write_in_worker = write_in_worker self.batch = [] self.num_rows_per_file = num_rows_per_file self.closed = False self.finalized = False - self.accumulate = accumulate + self.accumulate = (not self.save_as_dataset) and accumulate if not self.accumulate: self.finalize = super().finalize @@ -162,13 +157,22 @@ def finalize(self): return self.write_worker([], last=True) def write_main(self, fragments: Iterable[List[Union[pyarrow.Table, Path]]]): - for table in flatten(fragments): - if not self.write_in_worker: - pyarrow.parquet.write_to_dataset( - table=table, - root_path=self.path, - filesystem=self.filesystem, - ) + tables = list(flatten(fragments)) + if self.save_as_dataset: + for table in tables: + if not self.write_in_worker: + pyarrow.parquet.write_to_dataset( + table=table, + root_path=self.path, + filesystem=self.filesystem, + ) + else: + pyarrow.parquet.write_table( + table=pyarrow.concat_tables(tables), + where=self.path, + filesystem=self.filesystem, + ) + return pyarrow.dataset.dataset( self.path, format="parquet", filesystem=self.filesystem ) @@ -202,7 +206,7 @@ def write_parquet( path: Union[str, Path], *, write_in_worker: bool = False, - num_rows_per_file: int = 1024, + num_rows_per_file: Optional[int] = None, overwrite: bool = False, filesystem: Optional[pyarrow.fs.FileSystem] = None, accumulate: bool = True, @@ -216,7 +220,7 @@ def write_parquet( return data.write( ParquetWriter( - path, + path=path, num_rows_per_file=num_rows_per_file, overwrite=overwrite, write_in_worker=write_in_worker, diff --git a/edspdf/lazy_collection.py b/edspdf/lazy_collection.py index e3b60b9..225fd1d 100644 --- a/edspdf/lazy_collection.py +++ b/edspdf/lazy_collection.py @@ -325,6 +325,37 @@ def to(self, device: Union[str, Optional["torch.device"]] = None): # noqa F821 pipe.to(device) return self + def train(self, mode=True): + """ + Enables training mode on pytorch modules + + Parameters + ---------- + mode: bool + Whether to enable training or not + """ + + class context: + def __enter__(self): + pass + + def __exit__(ctx_self, type, value, traceback): + for name, proc in procs: + proc.train(was_training[name]) + + procs = [x for x in self.torch_components() if hasattr(x[1], "train")] + was_training = {name: proc.training for name, proc in procs} + for name, proc in procs: + proc.train(mode) + + return context() + + def eval(self): + """ + Enables evaluation mode on pytorch modules + """ + return self.train(False) + def worker_copy(self): return LazyCollection( reader=self.reader.worker_copy(), diff --git a/edspdf/pipes/classifiers/trainable.py b/edspdf/pipes/classifiers/trainable.py index 1f71e79..16d9c55 100644 --- a/edspdf/pipes/classifiers/trainable.py +++ b/edspdf/pipes/classifiers/trainable.py @@ -1,7 +1,7 @@ import json import os from pathlib import Path -from typing import Any, Dict, Iterable, Sequence, Set +from typing import Any, Dict, Iterable, List, Sequence, Set import torch import torch.nn.functional as F @@ -201,7 +201,12 @@ def forward(self, batch: Dict) -> Dict: return output - def postprocess(self, docs: Sequence[PDFDoc], output: Dict) -> Sequence[PDFDoc]: + def postprocess( + self, + docs: Sequence[PDFDoc], + output: Dict, + inputs: List[Dict[str, Any]], + ) -> Sequence[PDFDoc]: for b, label in zip( (b for doc in docs for b in doc.text_boxes), output["labels"].tolist(), diff --git a/edspdf/processing/multiprocessing.py b/edspdf/processing/multiprocessing.py 
index 0fbbe12..1ef7872 100644 --- a/edspdf/processing/multiprocessing.py +++ b/edspdf/processing/multiprocessing.py @@ -4,6 +4,7 @@ import gc import io import logging +import math import multiprocessing import multiprocessing.reduction import os @@ -28,13 +29,14 @@ from typing_extensions import TypedDict from edspdf.lazy_collection import LazyCollection -from edspdf.utils.collections import batchify, flatten +from edspdf.utils.collections import ( + batch_compress_dict, + batchify, + decompress_dict, + flatten, +) -batch_size_fns = { - "content_boxes": lambda batch: sum(len(doc.content_boxes) for doc in batch), - "pages": lambda batch: sum(len(doc.pages) for doc in batch), - "docs": len, -} +from .utils import apply_basic_pipes, batchify_fns, batchify_with_counts doc_size_fns = { "content_boxes": lambda doc: len(doc.content_boxes), @@ -48,21 +50,12 @@ Stage = TypedDict( "Stage", { - "cpu_components": List[Tuple[str, Callable, Dict]], + "cpu_components": List[Tuple[str, Callable, Dict, Any]], "gpu_component": Optional[Any], }, ) -def apply_basic_pipes(docs, pipes): - for name, pipe, kwargs in pipes: - if hasattr(pipe, "batch_process"): - docs = pipe.batch_process(docs) - else: - docs = [pipe(doc, **kwargs) for doc in docs] - return docs - - class ForkingPickler(dill.Pickler): """ ForkingPickler that uses dill instead of pickle to transfer objects between @@ -144,7 +137,76 @@ def revert(): return revert -# Should we check if the multiprocessing module of edspdf +def cpu_count(): # pragma: no cover + """ + Heavily inspired (partially copied) from joblib's loky + (https://github.com/joblib/loky/blob/2c21e/loky/backend/context.py#L83) + by Thomas Moreau and Olivier Grisel. + + Return the number of CPUs we can use to process data in parallel. + + The returned number of CPUs returns the minimum of: + * `os.cpu_count()` + * the CPU affinity settings + * cgroup CPU bandwidth limit (share of total CPU time allowed in a given job) + typically used in containerized environments like Docker + + Note that on Windows, the returned number of CPUs cannot exceed 61 (or 60 for + Python < 3.10), see: + https://bugs.python.org/issue26903. + + It is also always larger or equal to 1. 
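+
+    For example, in a container limited to 2.5 CPUs (cgroup quota of 250000us
+    over a 100000us period) on an 8-core machine with default affinity, this
+    returns min(8, 8, ceil(250000 / 100000)) = 3.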
+ """ + # Note: os.cpu_count() is allowed to return None in its docstring + os_cpu_count = os.cpu_count() or 1 + if sys.platform == "win32": + # Following loky's windows implementation + + _MAX_WINDOWS_WORKERS = 60 + if sys.version_info >= (3, 8): + from concurrent.futures.process import _MAX_WINDOWS_WORKERS + + if sys.version_info < (3, 10): + _MAX_WINDOWS_WORKERS = _MAX_WINDOWS_WORKERS - 1 + os_cpu_count = min(os_cpu_count, _MAX_WINDOWS_WORKERS) + + cpu_count_affinity = os_cpu_count + try: + cpu_count_affinity = len(os.sched_getaffinity(0)) + except (NotImplementedError, AttributeError): + pass + + # Cgroup CPU bandwidth limit available in Linux since 2.6 kernel + cpu_count_cgroup = os_cpu_count + cpu_max_fname = "/sys/fs/cgroup/cpu.max" + cfs_quota_fname = "/sys/fs/cgroup/cpu/cpu.cfs_quota_us" + cfs_period_fname = "/sys/fs/cgroup/cpu/cpu.cfs_period_us" + if os.path.exists(cpu_max_fname): + # cgroup v2 + # https://www.kernel.org/doc/html/latest/admin-guide/cgroup-v2.html + with open(cpu_max_fname) as fh: + cpu_quota_us, cpu_period_us = fh.read().strip().split() + elif os.path.exists(cfs_quota_fname) and os.path.exists(cfs_period_fname): + # cgroup v1 + # https://www.kernel.org/doc/html/latest/scheduler/sched-bwc.html#management + with open(cfs_quota_fname) as fh: + cpu_quota_us = fh.read().strip() + with open(cfs_period_fname) as fh: + cpu_period_us = fh.read().strip() + else: + cpu_quota_us = "max" + cpu_period_us = 100_000 + + if cpu_quota_us != "max": + cpu_quota_us = int(cpu_quota_us) + cpu_period_us = int(cpu_period_us) + if cpu_quota_us > 0 and cpu_period_us > 0: + cpu_count_cgroup = math.ceil(cpu_quota_us / cpu_period_us) + + return max(1, min(os_cpu_count, cpu_count_affinity, cpu_count_cgroup)) + + +# Should we check if the multiprocessing module of edsnlp # is responsible for this child process before replacing the pickler ? if ( multiprocessing.current_process() != "MainProcess" @@ -161,7 +223,13 @@ def revert(): else lambda *args, **kwargs: None ) -try: # pragma: no cover +if os.environ.get("TORCH_SHARING_STRATEGY"): + try: + torch.multiprocessing.set_sharing_strategy(os.environ["TORCH_SHARING_STRATEGY"]) + except NameError: + pass + +try: import torch # Torch may still be imported as a namespace package, so we can access the @@ -269,8 +337,17 @@ def __init__( self.num_stages = num_stages # noinspection PyUnresolvedReferences - def get_cpu_task(self, idx): - queue_readers = wait([queue._reader for queue in self.cpu_inputs_queues[idx]]) + def get_cpu_task(self, idx, get_instant_active_or_skip: bool = False): + queues = self.cpu_inputs_queues[idx] + if get_instant_active_or_skip: + # Don't get new tasks + queues = queues[1:] + queue_readers = wait( + [queue._reader for queue in queues], + timeout=0 if get_instant_active_or_skip else None, + ) + if len(queue_readers) == 0: + return None, None stage, queue = next( (stage, q) for stage, q in reversed(list(enumerate(self.cpu_inputs_queues[idx]))) @@ -324,15 +401,94 @@ def run(self): # Cannot pass torch tensor during init i think ? 
otherwise i get # ValueError: bad value(s) in fds_to_keep # mp._prctl_pr_set_pdeathsig(signal.SIGINT) + next_batch_id = self.cpu_idx + new_batch_iterator = None + + def split_task_into_new_batches(task): + nonlocal next_batch_id, new_batch_iterator + task_id, fragments = task + chunks = list(batchify(lc.reader.read_worker(fragments), lc.chunk_size)) + for chunk_idx, docs in enumerate(chunks): + # If we sort by size, we must first create the documents + # to have features against which we will sort + docs = apply_basic_pipes(docs, preprocess_pipes) + + if lc.sort_chunks: + docs.sort( + key=doc_size_fns.get( + lc.sort_chunks, + doc_size_fns["content_boxes"], + ) + ) + + batches = [ + batch + for batch in batchify_fns[lc.batch_by]( + docs, + batch_size=lc.batch_size, + ) + ] + + for batch_idx, batch in enumerate(batches): + assert len(batch) > 0 + batch_id = next_batch_id + + # We mark the task id only for the last batch of a task + # since the purpose of storing the task id is to know + # when the worker has finished processing the task, + # which is true only when the last batch has been + # processed + active_batches[batch_id] = ( + batch, + task_id + if (batch_idx == len(batches) - 1) + and (chunk_idx == len(chunks) - 1) + else None, + None, + ) + next_batch_id += num_cpu + # gpu_idx = None + # batch_id = we have just created a new batch + # result from the last stage = None + if batch_idx == len(batches) - 1 and chunk_idx == len(chunks) - 1: + new_batch_iterator = None + yield 0, (None, batch_id, None) + + new_batch_iterator = None def read_tasks(): - next_batch_id = self.cpu_idx + nonlocal new_batch_iterator + expect_new_tasks = True while expect_new_tasks or len(active_batches) > 0: + # Check that there are no more than `chunk_size` docs being processed. 
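+                # (this caps the number of documents in flight in this worker,
+                # and therefore bounds its memory usage)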
+ # If there is still room, we can process new batches + has_room_for_new_batches = ( + sum(len(ab[0]) for ab in active_batches.values()) < lc.chunk_size + ) + + # if new_batch_iterator is not None and len(active_batches) == 0: + # yield next(new_batch_iterator) + # continue + stage, task = self.exchanger.get_cpu_task( idx=self.cpu_idx, + # We don't have to wait for new active batches to come back if: + get_instant_active_or_skip=( + # - we have room for more batches + has_room_for_new_batches + # - and the batch iterator is still active + and new_batch_iterator is not None + ), ) + + # No active batch was returned, and by construction we have room for + # new batches, so we can start a new batch + if stage is None: + yield next(new_batch_iterator) + continue + # stage, task = next(iterator) # Prioritized STOP signal: something bad happened in another process # -> stop listening to input queues and raise StopIteration (return) @@ -348,52 +504,8 @@ def read_tasks(): # If first stage, we receive tasks that may require batching # again => we split them into chunks if stage == 0: - task_id, fragments = task - chunks = list( - batchify(lc.reader.read_worker(fragments), lc.chunk_size) - ) - for chunk_idx, docs in enumerate(chunks): - # If we sort by size, we must first create the documents - # to have features against which we will sort - docs = apply_basic_pipes(docs, preprocess_pipes) - - if lc.sort_chunks: - docs.sort( - key=doc_size_fns.get( - lc.sort_chunks, doc_size_fns["content_boxes"] - ) - ) - - batches = [ - batch - for batch in batchify( - docs, - batch_size=lc.batch_size, - formula=batch_size_fns[lc.batch_by], - ) - ] - - for batch_idx, batch in enumerate(batches): - assert len(batch) > 0 - batch_id = next_batch_id - - # We mark the task id only for the last batch of a task - # since the purpose of storing the task id is to know - # when the worker has finished processing the task, - # which is true only when the last batch has been - # processed - active_batches[batch_id] = ( - batch, - task_id - if (batch_idx == len(batches) - 1) - and (chunk_idx == len(chunks) - 1) - else None, - ) - next_batch_id += num_cpu - # gpu_idx = None - # batch_id = we have just created a new batch - # result from the last stage = None - yield stage, (None, batch_id, None) + new_batch_iterator = split_task_into_new_batches(task) + yield next(new_batch_iterator) else: yield stage, task @@ -401,16 +513,15 @@ def read_tasks(): lc: LazyCollection = load( self.lazy_collection_path, map_location=self.device ) + lc.eval() preprocess_pipes = [] num_cpu = self.exchanger.num_cpu_workers split_into_batches_after = lc.split_into_batches_after - if ( - split_into_batches_after is None - or lc.batch_by != "docs" - or lc.sort_chunks + if split_into_batches_after is None and ( + lc.batch_by != "docs" or lc.sort_chunks ): split_into_batches_after = next( - (p[0] for p in lc.pipeline if p[0] is not None), None + (s[0] for s in lc.pipeline if s[0] is not None), None ) is_before_split = split_into_batches_after is not None @@ -438,29 +549,41 @@ def read_tasks(): self.exchanger.put_results((None, 0, None, None)) for stage, (gpu_idx, batch_id, result) in read_tasks(): - docs, task_id = active_batches.pop(batch_id) + docs, task_id, inputs = active_batches.pop(batch_id) + count = len(docs) for name, pipe, *rest in lc.pipeline: if hasattr(pipe, "enable_cache"): pipe.enable_cache(batch_id) if stage > 0: gpu_pipe = stages[stage - 1]["gpu_component"] - docs = gpu_pipe.postprocess(docs, result) # type: ignore + docs = ( + 
gpu_pipe.postprocess(docs, result, inputs) + if getattr(gpu_pipe, "postprocess", None) is not None + else result + ) docs = apply_basic_pipes(docs, stages[stage]["cpu_components"]) gpu_pipe: "TrainablePipe" = stages[stage]["gpu_component"] if gpu_pipe is not None: - preprocessed = gpu_pipe.make_batch(docs) # type: ignore - active_batches[batch_id] = (docs, task_id) if gpu_idx is None: gpu_idx = batch_id % len(self.exchanger.gpu_worker_devices) - collated = gpu_pipe.collate(preprocessed) - collated = gpu_pipe.batch_to_device( - collated, - device=self.exchanger.gpu_worker_devices[gpu_idx], - ) + device = self.exchanger.gpu_worker_devices[gpu_idx] + if hasattr(gpu_pipe, "preprocess"): + inputs = [gpu_pipe.preprocess(doc) for doc in docs] + batch = decompress_dict(list(batch_compress_dict(inputs))) + batch = gpu_pipe.collate(batch) + batch = gpu_pipe.batch_to_device(batch, device=device) + else: + batch = gpu_pipe.prepare_batch(docs, device=device) + inputs = None + active_batches[batch_id] = (docs, task_id, inputs) self.exchanger.put_gpu( - item=(self.cpu_idx, batch_id, collated), + item=( + self.cpu_idx, + batch_id, + batch, + ), idx=gpu_idx, stage=stage, ) @@ -471,7 +594,7 @@ def read_tasks(): results, count = ( lc.writer.write_worker(docs) if lc.writer is not None - else (docs, len(docs)) + else (docs, count) ) self.exchanger.put_results( ( @@ -529,6 +652,7 @@ def run(self): # mp._prctl_pr_set_pdeathsig(signal.SIGINT) try: lc = load(self.lazy_collection_path, map_location=self.device) + lc.eval() stage_components = [ pipe # move_to_device(pipe, self.device) @@ -542,7 +666,7 @@ def run(self): # Inform the main process that we are ready self.exchanger.put_results((None, 0, None, None)) - with torch.no_grad(): + with torch.no_grad(), torch.cuda.amp.autocast(): while True: stage, task = self.exchanger.get_gpu_task(self.gpu_idx) if task is None: @@ -597,9 +721,6 @@ def __repr__(self): return f"" -DEFAULT_MAX_CPU_WORKERS = 4 - - def execute_multiprocessing_backend( lc: LazyCollection, ): @@ -637,6 +758,11 @@ def execute_multiprocessing_backend( + !!! warning "Caveat" + + Since workers can produce their results in any order, the order of the results + may not be the same as the order of the input tasks. 
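+
+    For instance (illustrative sketch, assuming `docs` is a lazy collection that
+    has already been mapped to a pipeline):
+
+    ```python
+    docs = docs.set_processing(
+        backend="multiprocessing",
+        num_cpu_workers=4,
+        show_progress=True,
+    )
+    ```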
+ """ try: TrainablePipe = sys.modules["edspdf.trainable_pipe"].TrainablePipe @@ -679,6 +805,7 @@ def execute_multiprocessing_backend( and num_gpu_workers > 0 ) + num_cpus = int(os.environ.get("EDSPDF_MAX_CPU_WORKERS") or cpu_count()) num_devices = 0 if requires_gpu: import torch @@ -687,10 +814,20 @@ def execute_multiprocessing_backend( logging.info(f"Number of available devices: {num_devices}") if num_gpu_workers is None: - num_gpu_workers = num_devices + num_gpu_workers = min(num_devices, num_cpus // 2) else: num_gpu_workers = 0 + if "torch" in sys.modules: + try: + import torch.multiprocessing + + os.environ[ + "TORCH_SHARING_STRATEGY" + ] = torch.multiprocessing.get_sharing_strategy() + except ImportError: # pragma: no cover + pass + if any(gpu_steps_candidates): if process_start_method == "fork": warnings.warn( @@ -705,11 +842,11 @@ def execute_multiprocessing_backend( logging.info(f"Switching process start method to {process_start_method}") mp = multiprocessing.get_context(process_start_method) - max_workers = max(min(mp.cpu_count() - num_gpu_workers, DEFAULT_MAX_CPU_WORKERS), 0) + max_cpu_workers = max(num_cpus - num_gpu_workers - 1, 0) num_cpu_workers = ( - (num_gpu_workers or max_workers) + max_cpu_workers if num_cpu_workers is None - else max_workers + num_cpu_workers + 1 + else max_cpu_workers + num_cpu_workers + 1 if num_cpu_workers < 0 else num_cpu_workers ) @@ -751,7 +888,7 @@ def execute_multiprocessing_backend( gpu_worker_devices=gpu_worker_devices, ) - lc = lc.to("cpu") + # lc = lc.to("cpu") cpu_workers = [] gpu_workers = [] @@ -827,7 +964,7 @@ def execute_multiprocessing_backend( if show_progress: from tqdm import tqdm - bar = tqdm(smoothing=0.1, mininterval=5.0) + bar = tqdm(smoothing=0.1, mininterval=1.0) def get_and_process_output(): outputs, count, cpu_idx, output_task_id = next(outputs_iterator) @@ -844,18 +981,10 @@ def get_and_process_output(): def process(): try: - with bar: - for input_task_id, items in enumerate( - batchify( - iterable=inputs_iterator, - batch_size=lc.chunk_size, - drop_last=False, - formula=lambda x: sum(item[1] for item in x), - ) + with bar, lc.eval(): + for input_task_id, (batch, batch_size) in enumerate( + batchify_with_counts(inputs_iterator, lc.chunk_size) ): - batch = [item[0] for item in items] - batch_size = sum(item[1] for item in items) - while all(sum(wl.values()) >= max_workload for wl in active_chunks): yield from get_and_process_output() @@ -913,11 +1042,11 @@ def process(): # with the cleanup of these processes ? 
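             # Workers that are still alive at this point are forcefully terminated.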
for i, worker in enumerate(gpu_workers): # pragma: no cover if worker.is_alive(): - logging.error(f"Killing ") + logging.error(f"Killing ") worker.kill() for i, worker in enumerate(cpu_workers): # pragma: no cover if worker.is_alive(): - logging.error(f"Killing ") + logging.error(f"Killing ") worker.kill() for queue_group in ( @@ -930,4 +1059,4 @@ def process(): queue.cancel_join_thread() gen = process() - return lc.writer.write_main(gen) if lc.writer is not None else flatten(gen) + return flatten(gen) if lc.writer is None else lc.writer.write_main(gen) diff --git a/edspdf/processing/simple.py b/edspdf/processing/simple.py index 7137241..6b70765 100644 --- a/edspdf/processing/simple.py +++ b/edspdf/processing/simple.py @@ -6,29 +6,16 @@ from edspdf.utils.collections import batchify, flatten +from .utils import apply_basic_pipes, batchify_fns + if TYPE_CHECKING: from edspdf.lazy_collection import LazyCollection -batch_size_fns = { - "content_boxes": lambda batch: sum(len(doc.content_boxes) for doc in batch), - "pages": lambda batch: sum(len(doc.pages) for doc in batch), - "docs": len, -} - doc_size_fns = { "content_boxes": lambda doc: len(doc.content_boxes), } -def apply_basic_pipes(docs, pipes): - for name, pipe, kwargs in pipes: - if hasattr(pipe, "batch_process"): - docs = pipe.batch_process(docs) - else: - docs = [pipe(doc, **kwargs) for doc in docs] - return docs - - def execute_simple_backend( lc: LazyCollection, ): @@ -45,11 +32,11 @@ def execute_simple_backend( show_progress = lc.show_progress split_into_batches_after = lc.split_into_batches_after - if split_into_batches_after is None or lc.batch_by != "docs" or lc.sort_chunks: + if split_into_batches_after is None and (lc.batch_by != "docs" or lc.sort_chunks): split_into_batches_after = next( - (p[0] for p in lc.pipeline if p[0] is not None), None + (s[0] for s in lc.pipeline if s[0] is not None), None ) - names = [step[0] for step in lc.pipeline] + [None] + names = [None] + [step[0] for step in lc.pipeline] chunk_components = lc.pipeline[: names.index(split_into_batches_after)] batch_components = lc.pipeline[names.index(split_into_batches_after) :] @@ -60,7 +47,7 @@ def process(): bar = tqdm(smoothing=0.1, mininterval=5.0) - with bar: + with bar, lc.eval(): for docs in batchify( ( subtask @@ -78,16 +65,8 @@ def process(): ) ) - batches = [ - batch - for batch in batchify( - docs, - batch_size=lc.batch_size, - formula=batch_size_fns.get(lc.batch_by, len), - ) - ] - - for batch in batches: + for batch in batchify_fns[lc.batch_by](docs, lc.batch_size): + count = len(batch) with no_grad(), lc.cache(): batch = apply_basic_pipes(batch, batch_components) @@ -98,7 +77,7 @@ def process(): yield result else: if show_progress: - bar.update(len(batch)) + bar.update(count) yield batch if writer is not None: result, count = writer.finalize() diff --git a/edspdf/processing/utils.py b/edspdf/processing/utils.py new file mode 100644 index 0000000..6ce559d --- /dev/null +++ b/edspdf/processing/utils.py @@ -0,0 +1,85 @@ +import types +from typing import Iterable, List, TypeVar + +from edspdf.utils.collections import batchify + + +def apply_basic_pipes(docs, pipes): + for name, pipe, kwargs in pipes: + if hasattr(pipe, "batch_process"): + docs = pipe.batch_process(docs) + else: + results = [] + for doc in docs: + res = pipe(doc, **kwargs) + if isinstance(res, types.GeneratorType): + results.extend(res) + else: + results.append(res) + docs = results + return docs + + +T = TypeVar("T") + + +def batchify_with_counts( + iterable, + batch_size, +): + 
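+    # Group (item, count) pairs into batches whose summed counts stay within
+    # `batch_size`, yielding (batch, total_count) tuples; a single item whose
+    # count already exceeds `batch_size` is yielded in its own batch.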
total = 0 + batch = [] + for item, count in iterable: + if len(batch) > 0 and total + count > batch_size: + yield batch, total + batch = [] + total = 0 + batch.append(item) + total += count + if len(batch) > 0: + yield batch, total + + +def batchify_by_content_boxes( + iterable: Iterable[T], + batch_size: int, + drop_last: bool = False, +) -> Iterable[List[T]]: + batch = [] + total = 0 + for item in iterable: + count = len(item.content_boxes) + if len(batch) > 0 and total + count > batch_size: + yield batch + batch = [] + total = 0 + batch.append(item) + total += count + if len(batch) > 0 and not drop_last: + yield batch + + +def batchify_by_pages( + iterable: Iterable[T], + batch_size: int, + drop_last: bool = False, +) -> Iterable[List[T]]: + batch = [] + total = 0 + for item in iterable: + count = len(item.pages) + if len(batch) > 0 and total + count > batch_size: + yield batch + batch = [] + total = 0 + batch.append(item) + total += count + if len(batch) > 0 and not drop_last: + yield batch + + +batchify_fns = { + "content_boxes": batchify_by_content_boxes, + "pages": batchify_by_pages, + "docs": batchify, +} diff --git a/edspdf/trainable_pipe.py b/edspdf/trainable_pipe.py index e1979ba..1654e9c 100644 --- a/edspdf/trainable_pipe.py +++ b/edspdf/trainable_pipe.py @@ -8,6 +8,7 @@ Dict, Generic, Iterable, + List, Optional, Sequence, Set, @@ -226,7 +227,7 @@ def post_init(self, gold_data: Iterable[PDFDoc], exclude: Set[str]): if hasattr(component, "post_init"): component.post_init(gold_data, exclude=exclude) - def preprocess(self, doc: PDFDoc) -> Dict[str, Any]: + def preprocess(self, doc: PDFDoc, **kwargs) -> Dict[str, Any]: """ Preprocess the document to extract features that will be used by the neural network to perform its predictions. @@ -243,7 +244,7 @@ def preprocess(self, doc: PDFDoc) -> Dict[str, Any]: the document. 
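 
             For instance, a pipe with a single child component named "embedding"
             would return something like ``{"embedding": {...}}``: the keys are
             the names of the child components (illustrative).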
""" return { - name: component.preprocess(doc) + name: component.preprocess(doc, **kwargs) for name, component in self.named_component_children() } @@ -288,10 +289,11 @@ def batch_to_device( """ return { name: ( - value.to(device) + (value.to(device) if device is not None else value) if hasattr(value, "to") else getattr(self, name).batch_to_device(value, device=device) if hasattr(self, name) + and hasattr(getattr(self, name), "batch_to_device") else value ) for name, value in batch.items() @@ -325,7 +327,8 @@ def make_batch( self, docs: Sequence[PDFDoc], supervision: bool = False, - ) -> Dict[str, Sequence[Any]]: + device: Optional[Union[str, torch.device]] = None, + ) -> BatchInput: """ Convenience method to preprocess a batch of documents and collate them Features corresponding to the same path are grouped together in a list, @@ -337,6 +340,8 @@ def make_batch( Batch of documents supervision: bool Whether to extract supervision features or not + device: Optional[Union[str, torch.device]] + Device to move the tensors to Returns ------- @@ -346,7 +351,10 @@ def make_batch( (self.preprocess_supervised(doc) if supervision else self.preprocess(doc)) for doc in docs ] - return decompress_dict(list(batch_compress_dict(batch))) + batch = decompress_dict(list(batch_compress_dict(batch))) + batch = self.collate(batch) + batch = self.batch_to_device(batch, device=device) + return batch def batch_process(self, docs: Sequence[PDFDoc]) -> Sequence[PDFDoc]: """ @@ -364,20 +372,23 @@ def batch_process(self, docs: Sequence[PDFDoc]) -> Sequence[PDFDoc]: Sequence[PDFDoc] Batch of updated documents """ - device = self.device with torch.no_grad(): - batch = self.make_batch(docs) - inputs = self.collate(batch) - inputs = self.batch_to_device(inputs, device=device) + inputs = [self.preprocess(doc) for doc in docs] + batch = decompress_dict(list(batch_compress_dict(inputs))) + batch = self.collate(batch) + batch = self.batch_to_device(batch, device=self.device) if hasattr(self, "compiled"): - res = self.compiled(inputs) + res = self.compiled(batch) else: - res = self.module_forward(inputs) - docs = self.postprocess(docs, res) + res = self.module_forward(batch) + docs = self.postprocess(docs, res, inputs) return docs def postprocess( - self, docs: Sequence[PDFDoc], batch: BatchOutput + self, + docs: Sequence[PDFDoc], + batch: BatchOutput, + inputs: List[Dict[str, Any]], ) -> Sequence[PDFDoc]: """ Update the documents with the predictions of the neural network, for instance @@ -391,6 +402,8 @@ def postprocess( Batch of documents batch: BatchOutput Batch of predictions, as returned by the forward method + inputs: List[Dict[str, Any]], + List of preprocessed features, as returned by the preprocess method Returns ------- @@ -447,7 +460,7 @@ def to_disk(self, path, exclude: Optional[Set[str]]): overrides[name] = pipe_overrides tensor_dict = { n: p - for n, p in self.named_parameters() + for n, p in (*self.named_parameters(), *self.named_buffers()) if object.__repr__(p) not in exclude } os.makedirs(path, exist_ok=True) diff --git a/tests/core/test_data.py b/tests/core/test_data.py index ee41a0c..3974c97 100644 --- a/tests/core/test_data.py +++ b/tests/core/test_data.py @@ -66,11 +66,11 @@ def test_write_data( ) if write_mode == "parquet": docs.write_parquet( - "file://" + str(tmp_path / "parquet" / "test.parquet"), + "file://" + str(tmp_path / "parquet" / "test"), converter=box_converter, write_in_worker=write_in_worker, ) - df = pd.read_parquet("file://" + str(tmp_path / "parquet" / "test.parquet")) + df = 
pd.read_parquet("file://" + str(tmp_path / "parquet" / "test")) elif write_mode == "pandas": if write_in_worker: pytest.skip()
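
Usage sketch for the updated parquet writing behaviour (illustrative only: `docs`
stands for a lazy collection and `box_converter` for the converter used in
`tests/core/test_data.py`; the output paths are placeholders):

```python
# Writing to a directory-like path (no suffix) produces a parquet dataset,
# split into files of at most `num_rows_per_file` rows (8192 by default).
# This mode is compatible with worker-side writing.
docs.write_parquet(
    "file:///tmp/output/dataset",
    converter=box_converter,
    write_in_worker=True,
    num_rows_per_file=1024,
)

# Writing to a path with a suffix produces a single parquet file, concatenated
# and written by the main process.
docs.write_parquet(
    "file:///tmp/output/result.parquet",
    converter=box_converter,
)

# Requesting a single file *and* split / worker-side writing is now rejected,
# as described in the changelog.
docs.write_parquet(
    "file:///tmp/output/result.parquet",
    converter=box_converter,
    write_in_worker=True,  # raises an error
)
```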