Skip to content

Commit

Permalink
test(embeddings): add cache directory tests
Browse files Browse the repository at this point in the history
This commit adds two tests for the cache_embeddings decorator in the
EmbeddingsCache class. The first test checks that the cache directory is
created when the cache is enabled. The second test checks that the cache
directory is not created when the cache is disabled. These tests help
ensure that the cache_embeddings decorator behaves as expected.
  • Loading branch information
Pouyanpi committed Sep 2, 2024
1 parent 8becd10 commit 87e12ce
Showing 1 changed file with 53 additions and 0 deletions.
53 changes: 53 additions & 0 deletions tests/test_cache_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile
import unittest
from typing import List
Expand Down Expand Up @@ -177,3 +178,55 @@ async def test_cache_embeddings():
texts,
[[104.0, 101.0, 108.0, 108.0, 111.0], [119.0, 111.0, 114.0, 108.0, 100.0]],
)


class TestClass:
def __init__(self, cache_config):
self._cache_config = cache_config

@property
def cache_config(self):
return self._cache_config

@cache_embeddings
async def get_embeddings(self, texts):
return [[float(ord(c)) for c in text] for text in texts]


@pytest.mark.asyncio
async def test_cache_dir_created():
with tempfile.TemporaryDirectory() as temp_dir:
cache_config = EmbeddingsCacheConfig(
enabled=True,
key_generator="md5",
store="filesystem",
store_config={"cache_dir": os.path.join(temp_dir, "exist")},
)

test_class = TestClass(cache_config)

await test_class.get_embeddings(["test"])

# Assert that the cache directory exists
assert os.path.exists(cache_config.store_config["cache_dir"])


@pytest.mark.asyncio
async def test_cache_dir_not_created():
with tempfile.TemporaryDirectory() as temp_dir:
cache_config = EmbeddingsCacheConfig(
enabled=False,
key_generator="md5",
store="filesystem",
store_config={"cache_dir": os.path.join(temp_dir, "exist")},
)

test_class = TestClass(cache_config)

test_class.cache_config.store_config["cache_dir"] = os.path.join(
temp_dir, "nonexistent"
)

await test_class.get_embeddings(["test"])

assert not os.path.exists(os.path.join(temp_dir, "nonexistent"))

0 comments on commit 87e12ce

Please sign in to comment.