Skip to content

Commit

Permalink
Merge pull request #319 from aurelio-labs/james/update-pinecone-test
Browse files Browse the repository at this point in the history
fix: add skip to vit
  • Loading branch information
jamescalam authored Jun 12, 2024
2 parents 5120edc + c4685d8 commit 1916fec
Showing 1 changed file with 16 additions and 0 deletions.
16 changes: 16 additions & 0 deletions tests/unit/encoders/test_vit.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import numpy as np
import pytest
import torch
Expand Down Expand Up @@ -48,29 +49,44 @@ def test_vit_encoder__import_errors_torchvision(self, mocker):
with pytest.raises(ImportError):
VitEncoder()

@pytest.mark.skipif(
os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"
)
def test_vit_encoder_initialization(self):
assert vit_encoder.name == test_model_name
assert vit_encoder.type == "huggingface"
assert vit_encoder.score_threshold == 0.5
assert vit_encoder.device == device

@pytest.mark.skipif(
os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"
)
def test_vit_encoder_call(self, dummy_pil_image):
encoded_images = vit_encoder([dummy_pil_image] * 3)

assert len(encoded_images) == 3
assert set(map(len, encoded_images)) == {embed_dim}

@pytest.mark.skipif(
os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"
)
def test_vit_encoder_call_misshaped(self, dummy_pil_image, misshaped_pil_image):
encoded_images = vit_encoder([dummy_pil_image, misshaped_pil_image])

assert len(encoded_images) == 2
assert set(map(len, encoded_images)) == {embed_dim}

@pytest.mark.skipif(
os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"
)
def test_vit_encoder_process_images_device(self, dummy_pil_image):
imgs = vit_encoder._process_images([dummy_pil_image] * 3)["pixel_values"]

assert imgs.device.type == device

@pytest.mark.skipif(
os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"
)
def test_vit_encoder_ensure_rgb(self, dummy_black_and_white_img):
rgb_image = vit_encoder._ensure_rgb(dummy_black_and_white_img)

Expand Down

0 comments on commit 1916fec

Please sign in to comment.