From 1c0c0955332e060e73b309a4d98c46ffd87d5c35 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Perceval=20Wajsb=C3=BCrt?= Date: Thu, 15 Feb 2024 23:36:01 +0100 Subject: [PATCH] fix: support empty tensors in cnn & transformer --- edspdf/pipes/embeddings/huggingface_embedding.py | 9 +++++---- tests/conftest.py | 14 -------------- tests/pipes/embeddings/test_custom.py | 9 ++++++++- tests/pipes/embeddings/test_huggingface.py | 8 ++++++-- 4 files changed, 19 insertions(+), 21 deletions(-) diff --git a/edspdf/pipes/embeddings/huggingface_embedding.py b/edspdf/pipes/embeddings/huggingface_embedding.py index e2c56e39..58a3e3ad 100644 --- a/edspdf/pipes/embeddings/huggingface_embedding.py +++ b/edspdf/pipes/embeddings/huggingface_embedding.py @@ -309,14 +309,15 @@ def collate(self, batch): "line_window_offsets_flat": line_window_offsets_flat, } if self.use_image: - collated["pixel_values"] = torch.stack( + collated["pixel_values"] = torch.as_tensor( [ - torch.from_numpy(page_pixels) + page_pixels for sample_pages in batch["pixel_values"] for page_pixels in sample_pages ], - dim=0, - ).repeat_interleave(torch.as_tensor(windows_count_per_page), dim=0) + ).repeat_interleave( + torch.as_tensor(windows_count_per_page, dtype=torch.long), dim=0 + ) return collated def forward(self, batch): diff --git a/tests/conftest.py b/tests/conftest.py index 289e7831..dc274392 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,4 @@ -import copy import os -from functools import lru_cache from pathlib import Path import pytest @@ -53,18 +51,6 @@ def error_pdf(): return path.read_bytes() -@lru_cache(maxsize=1) -def make_pdfdoc(pdf): - from edspdf.pipes.extractors.pdfminer import PdfMinerExtractor - - return PdfMinerExtractor(render_pages=True)(pdf) - - -@fixture() -def pdfdoc(pdf): - return copy.deepcopy(make_pdfdoc(pdf)) - - @fixture(scope="session") def dummy_dataset(tmpdir_factory, pdf): tmp_path = tmpdir_factory.mktemp("datasets") diff --git a/tests/pipes/embeddings/test_custom.py b/tests/pipes/embeddings/test_custom.py index 31f5a6f3..b0571177 100644 --- a/tests/pipes/embeddings/test_custom.py +++ b/tests/pipes/embeddings/test_custom.py @@ -3,9 +3,10 @@ from edspdf.pipes.embeddings.embedding_combiner import EmbeddingCombiner from edspdf.pipes.embeddings.simple_text_embedding import SimpleTextEmbedding from edspdf.pipes.embeddings.sub_box_cnn_pooler import SubBoxCNNPooler +from edspdf.pipes.extractors.pdfminer import PdfMinerExtractor -def test_custom_embedding(pdfdoc, tmp_path): +def test_custom_embedding(pdf, error_pdf, tmp_path): embedding = BoxTransformer( num_heads=4, dropout_p=0.1, @@ -35,8 +36,14 @@ def test_custom_embedding(pdfdoc, tmp_path): ), ) str(embedding) + + extractor = PdfMinerExtractor(render_pages=True) + pdfdoc = extractor(pdf) pdfdoc.text_boxes[0].text = "Very long word of 150 letters : " + "x" * 150 embedding.post_init([pdfdoc], set()) embedding(pdfdoc) embedding.save_extra_data(tmp_path, set()) embedding.load_extra_data(tmp_path, set()) + + # Test empty document + embedding(extractor(error_pdf)) diff --git a/tests/pipes/embeddings/test_huggingface.py b/tests/pipes/embeddings/test_huggingface.py index e82f386d..3e619b4e 100644 --- a/tests/pipes/embeddings/test_huggingface.py +++ b/tests/pipes/embeddings/test_huggingface.py @@ -1,7 +1,8 @@ from edspdf.pipes.embeddings.huggingface_embedding import HuggingfaceEmbedding +from edspdf.pipes.extractors.pdfminer import PdfMinerExtractor -def test_huggingface_embedding(pdfdoc): +def test_huggingface_embedding(pdf, error_pdf): embedding = HuggingfaceEmbedding( pipeline=None, name="huggingface", @@ -15,4 +16,7 @@ def test_huggingface_embedding(pdfdoc): "height": embedding.hf_model.config.input_size, "width": embedding.hf_model.config.input_size, } - embedding(pdfdoc) + + extractor = PdfMinerExtractor(render_pages=True) + embedding(extractor(pdf)) + embedding(extractor(error_pdf))