Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Batch encoding for TEI encoder #423

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Prev Previous commit
Next Next commit
More pytests for TestHFEndpointEncoder to increase coverage.
Siraj-Aizlewood committed Oct 4, 2024
commit f99539b83c096e28cf685206a3b8f2c10a6297c5
35 changes: 35 additions & 0 deletions tests/unit/encoders/test_hfendpointencoder.py
Original file line number Diff line number Diff line change
@@ -83,3 +83,38 @@ def test_initialization_failure_query_exception(self, requests_mock, mocker):
"HuggingFace endpoint client failed to initialize. Error: Initialization error"
in str(exc_info.value)
)

def test_no_embeddings_returned(self, encoder, requests_mock):
# Mock the response to return an empty list, simulating no embeddings
requests_mock.post(
"https://api-inference.huggingface.co/models/bert-base-uncased",
json=[],
status_code=200,
)
with pytest.raises(ValueError) as exc_info:
encoder(["Hello World!"])
assert "No embeddings returned from the query." in str(exc_info.value)

def test_no_embeddings_for_batch(self, encoder, requests_mock):
# Mock the response to simulate a server error
requests_mock.post(
"https://api-inference.huggingface.co/models/bert-base-uncased",
text="Error",
status_code=500,
)
with pytest.raises(ValueError) as exc_info:
encoder(["Hello World!"])
assert (
"No embeddings returned for batch. Error: Query failed with status 500: Error"
in str(exc_info.value)
)

def test_embeddings_extend(self, encoder, requests_mock):
# Mock the response to return a list of embeddings
requests_mock.post(
"https://api-inference.huggingface.co/models/bert-base-uncased",
json=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
status_code=200,
)
embeddings = encoder(["Hello World!", "Test"])
assert embeddings == [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]