Skip to content

Commit

Permalink
fix: support empty tensors in cnn & transformer
Browse files Browse the repository at this point in the history
  • Loading branch information
percevalw committed Feb 15, 2024
1 parent 86f771d commit 1c0c095
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 21 deletions.
9 changes: 5 additions & 4 deletions edspdf/pipes/embeddings/huggingface_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,14 +309,15 @@ def collate(self, batch):
"line_window_offsets_flat": line_window_offsets_flat,
}
if self.use_image:
collated["pixel_values"] = torch.stack(
collated["pixel_values"] = torch.as_tensor(
[
torch.from_numpy(page_pixels)
page_pixels
for sample_pages in batch["pixel_values"]
for page_pixels in sample_pages
],
dim=0,
).repeat_interleave(torch.as_tensor(windows_count_per_page), dim=0)
).repeat_interleave(
torch.as_tensor(windows_count_per_page, dtype=torch.long), dim=0
)
return collated

def forward(self, batch):
Expand Down
14 changes: 0 additions & 14 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import copy
import os
from functools import lru_cache
from pathlib import Path

import pytest
Expand Down Expand Up @@ -53,18 +51,6 @@ def error_pdf():
return path.read_bytes()


@lru_cache(maxsize=1)
def make_pdfdoc(pdf):
from edspdf.pipes.extractors.pdfminer import PdfMinerExtractor

return PdfMinerExtractor(render_pages=True)(pdf)


@fixture()
def pdfdoc(pdf):
return copy.deepcopy(make_pdfdoc(pdf))


@fixture(scope="session")
def dummy_dataset(tmpdir_factory, pdf):
tmp_path = tmpdir_factory.mktemp("datasets")
Expand Down
9 changes: 8 additions & 1 deletion tests/pipes/embeddings/test_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
from edspdf.pipes.embeddings.embedding_combiner import EmbeddingCombiner
from edspdf.pipes.embeddings.simple_text_embedding import SimpleTextEmbedding
from edspdf.pipes.embeddings.sub_box_cnn_pooler import SubBoxCNNPooler
from edspdf.pipes.extractors.pdfminer import PdfMinerExtractor


def test_custom_embedding(pdfdoc, tmp_path):
def test_custom_embedding(pdf, error_pdf, tmp_path):
embedding = BoxTransformer(
num_heads=4,
dropout_p=0.1,
Expand Down Expand Up @@ -35,8 +36,14 @@ def test_custom_embedding(pdfdoc, tmp_path):
),
)
str(embedding)

extractor = PdfMinerExtractor(render_pages=True)
pdfdoc = extractor(pdf)
pdfdoc.text_boxes[0].text = "Very long word of 150 letters : " + "x" * 150
embedding.post_init([pdfdoc], set())
embedding(pdfdoc)
embedding.save_extra_data(tmp_path, set())
embedding.load_extra_data(tmp_path, set())

# Test empty document
embedding(extractor(error_pdf))
8 changes: 6 additions & 2 deletions tests/pipes/embeddings/test_huggingface.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from edspdf.pipes.embeddings.huggingface_embedding import HuggingfaceEmbedding
from edspdf.pipes.extractors.pdfminer import PdfMinerExtractor


def test_huggingface_embedding(pdfdoc):
def test_huggingface_embedding(pdf, error_pdf):
embedding = HuggingfaceEmbedding(
pipeline=None,
name="huggingface",
Expand All @@ -15,4 +16,7 @@ def test_huggingface_embedding(pdfdoc):
"height": embedding.hf_model.config.input_size,
"width": embedding.hf_model.config.input_size,
}
embedding(pdfdoc)

extractor = PdfMinerExtractor(render_pages=True)
embedding(extractor(pdf))
embedding(extractor(error_pdf))

0 comments on commit 1c0c095

Please sign in to comment.