From 85b02eb16e646074be15f454a3df0c774bebb694 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Perceval=20Wajsb=C3=BCrt?= Date: Fri, 16 Feb 2024 11:56:18 +0100 Subject: [PATCH] test: improve coverage --- edspdf/pipeline.py | 2 +- edspdf/trainable_pipe.py | 5 ----- tests/core/test_pipeline.py | 13 +++++++------ 3 files changed, 8 insertions(+), 12 deletions(-) diff --git a/edspdf/pipeline.py b/edspdf/pipeline.py index d320fd5..83ecec1 100644 --- a/edspdf/pipeline.py +++ b/edspdf/pipeline.py @@ -360,7 +360,7 @@ def to_doc(doc): backend = accelerator elif isinstance(accelerator, dict): kwargs = dict(accelerator) - backend = accelerator.pop("@accelerator", "simple") + backend = kwargs.pop("@accelerator", "simple") elif "Accelerator" in type(accelerator).__name__: backend = ( "multiprocessing" diff --git a/edspdf/trainable_pipe.py b/edspdf/trainable_pipe.py index 2b1c23a..e1979ba 100644 --- a/edspdf/trainable_pipe.py +++ b/edspdf/trainable_pipe.py @@ -203,11 +203,6 @@ def named_component_children(self): if isinstance(module, TrainablePipe): yield name, module - def named_component_modules(self): - for name, module in self.named_modules(): - if isinstance(module, TrainablePipe): - yield name, module - def post_init(self, gold_data: Iterable[PDFDoc], exclude: Set[str]): """ This method completes the attributes of the component, by looking at some diff --git a/tests/core/test_pipeline.py b/tests/core/test_pipeline.py index e4cc233..637a76a 100644 --- a/tests/core/test_pipeline.py +++ b/tests/core/test_pipeline.py @@ -346,12 +346,13 @@ def error_pipe(doc: PDFDoc): def test_deprecated_multiprocessing_gpu_stub(frozen_pipeline, pdf, letter_pdf): edspdf.accelerators.multiprocessing.MAX_NUM_PROCESSES = 2 - accelerator = edspdf.accelerators.multiprocessing.MultiprocessingAccelerator( - batch_size=2, - num_gpu_workers=1, - num_cpu_workers=1, - gpu_worker_devices=["cpu"], - ) + accelerator = { + "@accelerator": "multiprocessing", + "batch_size": 2, + "num_gpu_workers": 1, + "num_cpu_workers": 1, + "gpu_worker_devices": ["cpu"], + } list( frozen_pipeline.pipe( chain.from_iterable(