diff --git a/CHANGES.md b/CHANGES.md index a21bfb5..523c2c5 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -2,6 +2,8 @@ ## Unreleased +- Added implementation and software tests for `CrateDBCache`, + deriving from `SQLAlchemyCache`. ## v0.0.0 - 2024-12-16 - Make it work diff --git a/langchain_cratedb/__init__.py b/langchain_cratedb/__init__.py index 91f4c93..d6511a9 100644 --- a/langchain_cratedb/__init__.py +++ b/langchain_cratedb/__init__.py @@ -5,6 +5,7 @@ patch_sqlalchemy_dialect() +from langchain_cratedb.cache import CrateDBCache from langchain_cratedb.chat_history import CrateDBChatMessageHistory from langchain_cratedb.loaders import CrateDBLoader from langchain_cratedb.vectorstores import ( @@ -20,6 +21,7 @@ del metadata # optional, avoids polluting the results of dir(__package__) __all__ = [ + "CrateDBCache", "CrateDBChatMessageHistory", "CrateDBLoader", "CrateDBVectorStore", diff --git a/langchain_cratedb/cache.py b/langchain_cratedb/cache.py new file mode 100644 index 0000000..d9866ca --- /dev/null +++ b/langchain_cratedb/cache.py @@ -0,0 +1,18 @@ +import typing as t + +import sqlalchemy as sa +from langchain_community.cache import FullLLMCache, SQLAlchemyCache +from sqlalchemy_cratedb.support import refresh_after_dml + + +class CrateDBCache(SQLAlchemyCache): + """ + CrateDB adapter for LangChain standard / full cache subsystem. + It builds upon SQLAlchemyCache 1:1. + """ + + def __init__( + self, engine: sa.Engine, cache_schema: t.Type[FullLLMCache] = FullLLMCache + ): + refresh_after_dml(engine) + super().__init__(engine, cache_schema) diff --git a/langchain_cratedb/patches.py b/langchain_cratedb/patches.py index 57c7d0a..ae7ac3b 100644 --- a/langchain_cratedb/patches.py +++ b/langchain_cratedb/patches.py @@ -1,3 +1,20 @@ +import warnings + +from sqlalchemy_cratedb.compiler import CrateDDLCompiler + + +def ddl_compiler_visit_create_index(self, create, **kw) -> str: # type: ignore[no-untyped-def] + """ + CrateDB does not support `CREATE INDEX` statements. + """ + warnings.warn( + "CrateDB does not support `CREATE INDEX` statements, " + "they will be omitted when generating DDL statements.", + stacklevel=2, + ) + return "SELECT 1" + + def patch_sqlalchemy_dialect() -> None: """ Fixes `AttributeError: 'CrateCompilerSA20' object has no attribute 'visit_on_conflict_do_update'` @@ -10,6 +27,7 @@ def patch_sqlalchemy_dialect() -> None: CrateCompiler.visit_on_conflict_do_update = PGCompiler.visit_on_conflict_do_update CrateCompiler._on_conflict_target = PGCompiler._on_conflict_target + CrateDDLCompiler.visit_create_index = ddl_compiler_visit_create_index patch_sqlalchemy_dialect() diff --git a/tests/integration_tests/cache/__init__.py b/tests/integration_tests/cache/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration_tests/cache/test_standard_mcsa_cache.py b/tests/integration_tests/cache/test_standard_mcsa_cache.py new file mode 100644 index 0000000..df8c48a --- /dev/null +++ b/tests/integration_tests/cache/test_standard_mcsa_cache.py @@ -0,0 +1,111 @@ +""" +Test standard and semantic caching. +Derived from Memcached and SQLAlchemy. 
+ +Source: https://github.com/langchain-ai/langchain/blob/langchain-core%3D%3D0.3.28/libs/community/tests/integration_tests/cache/test_memcached_cache.py +""" + +import pytest +import sqlalchemy as sa +from langchain.globals import get_llm_cache, set_llm_cache +from langchain_core.caches import BaseCache +from langchain_core.outputs import Generation, LLMResult + +from langchain_cratedb import CrateDBCache +from tests.utils import FakeLLM + + +@pytest.fixture() +def cache(engine: sa.Engine) -> BaseCache: + return CrateDBCache(engine=engine) + + +def test_memcached_cache(cache: BaseCache) -> None: + """Test general caching""" + + set_llm_cache(cache) + llm = FakeLLM() + + params = llm.dict() + params["stop"] = None + llm_string = str(sorted([(k, v) for k, v in params.items()])) + get_llm_cache().update("foo", llm_string, [Generation(text="fizz")]) + output = llm.generate(["foo"]) + expected_output = LLMResult( + generations=[[Generation(text="fizz")]], + llm_output={}, + ) + assert output == expected_output + # clear the cache + get_llm_cache().clear() + + +def test_memcached_cache_flush(cache: BaseCache) -> None: + """Test flushing cache""" + + set_llm_cache(cache) + llm = FakeLLM() + + params = llm.dict() + params["stop"] = None + llm_string = str(sorted([(k, v) for k, v in params.items()])) + get_llm_cache().update("foo", llm_string, [Generation(text="fizz")]) + output = llm.generate(["foo"]) + expected_output = LLMResult( + generations=[[Generation(text="fizz")]], + llm_output={}, + ) + assert output == expected_output + # clear the cache + get_llm_cache().clear(delay=0, noreply=False) + + # After cache has been cleared, the result shouldn't be the same + output = llm.generate(["foo"]) + assert output != expected_output + + +def test_sqlalchemy_cache(engine: sa.Engine) -> None: + """Test custom_caching behavior.""" + + from sqlalchemy_cratedb.support import patch_autoincrement_timestamp + + patch_autoincrement_timestamp() + + Base = sa.orm.declarative_base() + + class FulltextLLMCache(Base): # type: ignore + """CrateDB table for fulltext-indexed LLM Cache.""" + + __tablename__ = "llm_cache_fulltext" + # TODO: Original. Can it be converged by adding a polyfill to + # `sqlalchemy-cratedb`? 
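+        # Workaround: CrateDB provides no sequences, so the BigInteger column
+        # below uses a timestamp `server_default` as a stand-in for an
+        # autoincrementing primary key.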
+ # id = Column(Integer, Sequence("cache_id"), primary_key=True) + id = sa.Column(sa.BigInteger, server_default=sa.func.now(), primary_key=True) + prompt = sa.Column(sa.String, nullable=False) + llm = sa.Column(sa.String, nullable=False) + idx = sa.Column(sa.Integer) + response = sa.Column(sa.String) + + set_llm_cache(CrateDBCache(engine, FulltextLLMCache)) + get_llm_cache().clear() + + llm = FakeLLM() + params = llm.dict() + params["stop"] = None + llm_string = str(sorted([(k, v) for k, v in params.items()])) + get_llm_cache().update("foo", llm_string, [Generation(text="fizz")]) + output = llm.generate(["foo", "bar", "foo"]) + expected_cache_output = [Generation(text="foo")] + cache_output = get_llm_cache().lookup("bar", llm_string) + assert cache_output == expected_cache_output + set_llm_cache(None) + expected_generations = [ + [Generation(text="fizz")], + [Generation(text="foo")], + [Generation(text="fizz")], + ] + expected_output = LLMResult( + generations=expected_generations, + llm_output=None, + ) + assert output == expected_output diff --git a/tests/integration_tests/cache/test_standard_sqlite_cache.py b/tests/integration_tests/cache/test_standard_sqlite_cache.py new file mode 100644 index 0000000..f26029f --- /dev/null +++ b/tests/integration_tests/cache/test_standard_sqlite_cache.py @@ -0,0 +1,216 @@ +""" +Test caching for LLMs and ChatModels, derived from tests for SQLite. + +Source: https://github.com/langchain-ai/langchain/blob/langchain-core%3D%3D0.3.28/libs/community/tests/unit_tests/test_cache.py +""" + +from typing import Dict, Generator, List, Union + +import pytest +import sqlalchemy as sa +from _pytest.fixtures import FixtureRequest +from langchain_core.caches import BaseCache, InMemoryCache +from langchain_core.language_models import FakeListChatModel, FakeListLLM +from langchain_core.language_models.chat_models import BaseChatModel +from langchain_core.language_models.llms import BaseLLM +from langchain_core.load import dumps +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage +from langchain_core.outputs import ChatGeneration +from sqlalchemy.orm import Session + +from langchain_cratedb import CrateDBCache + +try: + from sqlalchemy.orm import declarative_base +except ImportError: + from sqlalchemy.ext.declarative import declarative_base # noqa: F401 + +from langchain.globals import get_llm_cache, set_llm_cache +from langchain_core.outputs import Generation + + +@pytest.fixture(params=["memory", "cratedb"]) +def cache(request: FixtureRequest, engine: sa.Engine) -> BaseCache: + if request.param == "memory": + return InMemoryCache() + elif request.param == "cratedb": + return CrateDBCache(engine=engine) + else: + raise NotImplementedError(f"Cache type not implemented: {request.param}") + + +@pytest.fixture(autouse=True) +def set_cache_and_teardown(cache: BaseCache) -> Generator[None, None, None]: + # Will be run before each test + set_llm_cache(cache) + if llm_cache := get_llm_cache(): + llm_cache.clear() + else: + raise ValueError("Cache not set. This should never happen.") + + yield + + # Will be run after each test + if llm_cache := get_llm_cache(): + llm_cache.clear() + set_llm_cache(None) + else: + raise ValueError("Cache not set. This should never happen.") + + +async def test_llm_caching() -> None: + prompt = "How are you?" 
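+    # The cache is seeded with `cached_response` below; a cache hit is proven
+    # when `invoke()` returns it instead of FakeListLLM's canned `response`.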
+    response = "Test response"
+    cached_response = "Cached test response"
+    llm = FakeListLLM(responses=[response])
+    if llm_cache := get_llm_cache():
+        # sync test
+        llm_cache.update(
+            prompt=prompt,
+            llm_string=create_llm_string(llm),
+            return_val=[Generation(text=cached_response)],
+        )
+        assert llm.invoke(prompt) == cached_response
+        # async test
+        await llm_cache.aupdate(
+            prompt=prompt,
+            llm_string=create_llm_string(llm),
+            return_val=[Generation(text=cached_response)],
+        )
+        assert await llm.ainvoke(prompt) == cached_response
+    else:
+        raise ValueError(
+            "The cache is not set. This should never happen, as the pytest fixture "
+            "`set_cache_and_teardown` always sets the cache."
+        )
+
+
+def test_old_sqlite_llm_caching() -> None:
+    llm_cache = get_llm_cache()
+    if isinstance(llm_cache, CrateDBCache):
+        prompt = "How are you?"
+        response = "Test response"
+        cached_response = "Cached test response"
+        llm = FakeListLLM(responses=[response])
+        items = [
+            llm_cache.cache_schema(
+                prompt=prompt,
+                llm=create_llm_string(llm),
+                response=cached_response,
+                idx=0,
+            )
+        ]
+        with Session(llm_cache.engine) as session, session.begin():
+            for item in items:
+                session.merge(item)
+        assert llm.invoke(prompt) == cached_response
+
+
+async def test_chat_model_caching() -> None:
+    prompt: List[BaseMessage] = [HumanMessage(content="How are you?")]
+    response = "Test response"
+    cached_response = "Cached test response"
+    cached_message = AIMessage(content=cached_response)
+    llm = FakeListChatModel(responses=[response])
+    if llm_cache := get_llm_cache():
+        # sync test
+        llm_cache.update(
+            prompt=dumps(prompt),
+            llm_string=llm._get_llm_string(),
+            return_val=[ChatGeneration(message=cached_message)],
+        )
+        result = llm.invoke(prompt)
+        assert isinstance(result, AIMessage)
+        assert result.content == cached_response
+
+        # async test
+        await llm_cache.aupdate(
+            prompt=dumps(prompt),
+            llm_string=llm._get_llm_string(),
+            return_val=[ChatGeneration(message=cached_message)],
+        )
+        result = await llm.ainvoke(prompt)
+        assert isinstance(result, AIMessage)
+        assert result.content == cached_response
+    else:
+        raise ValueError(
+            "The cache is not set. This should never happen, as the pytest fixture "
+            "`set_cache_and_teardown` always sets the cache."
+        )
+
+
+async def test_chat_model_caching_params() -> None:
+    prompt: List[BaseMessage] = [HumanMessage(content="How are you?")]
+    response = "Test response"
+    cached_response = "Cached test response"
+    cached_message = AIMessage(content=cached_response)
+    llm = FakeListChatModel(responses=[response])
+    if llm_cache := get_llm_cache():
+        # sync test
+        llm_cache.update(
+            prompt=dumps(prompt),
+            llm_string=llm._get_llm_string(functions=[]),
+            return_val=[ChatGeneration(message=cached_message)],
+        )
+        result = llm.invoke(prompt, functions=[])
+        result_no_params = llm.invoke(prompt)
+        assert isinstance(result, AIMessage)
+        assert result.content == cached_response
+        assert isinstance(result_no_params, AIMessage)
+        assert result_no_params.content == response
+
+        # async test
+        await llm_cache.aupdate(
+            prompt=dumps(prompt),
+            llm_string=llm._get_llm_string(functions=[]),
+            return_val=[ChatGeneration(message=cached_message)],
+        )
+        result = await llm.ainvoke(prompt, functions=[])
+        result_no_params = await llm.ainvoke(prompt)
+        assert isinstance(result, AIMessage)
+        assert result.content == cached_response
+        assert isinstance(result_no_params, AIMessage)
+        assert result_no_params.content == response
+    else:
+        raise ValueError(
+            "The cache is not set. This should never happen, as the pytest fixture "
+            "`set_cache_and_teardown` always sets the cache."
+        )
+
+
+async def test_llm_cache_clear() -> None:
+    prompt = "How are you?"
+    expected_response = "Test response"
+    cached_response = "Cached test response"
+    llm = FakeListLLM(responses=[expected_response])
+    if llm_cache := get_llm_cache():
+        # sync test
+        llm_cache.update(
+            prompt=prompt,
+            llm_string=create_llm_string(llm),
+            return_val=[Generation(text=cached_response)],
+        )
+        llm_cache.clear()
+        response = llm.invoke(prompt)
+        assert response == expected_response
+
+        # async test
+        await llm_cache.aupdate(
+            prompt=prompt,
+            llm_string=create_llm_string(llm),
+            return_val=[Generation(text=cached_response)],
+        )
+        await llm_cache.aclear()
+        response = await llm.ainvoke(prompt)
+        assert response == expected_response
+    else:
+        raise ValueError(
+            "The cache is not set. This should never happen, as the pytest fixture "
+            "`set_cache_and_teardown` always sets the cache."
+        )
+
+
+def create_llm_string(llm: Union[BaseLLM, BaseChatModel]) -> str:
+    _dict: Dict = llm.dict()
+    _dict["stop"] = None
+    return str(sorted([(k, v) for k, v in _dict.items()]))
diff --git a/tests/utils.py b/tests/utils.py
new file mode 100644
index 0000000..5cc876b
--- /dev/null
+++ b/tests/utils.py
@@ -0,0 +1,64 @@
+"""
+Fake LLM wrapper for testing purposes.
+
+Source: https://github.com/langchain-ai/langchain/blob/langchain-core%3D%3D0.3.28/libs/langchain/tests/unit_tests/llms/fake_llm.py
+"""
+
+from typing import Any, Dict, List, Mapping, Optional, cast
+
+from langchain_core.callbacks.manager import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import LLM
+from pydantic import model_validator
+
+
+class FakeLLM(LLM):
+    """Fake LLM wrapper for testing purposes."""
+
+    queries: Optional[Mapping] = None
+    sequential_responses: Optional[bool] = False
+    response_index: int = 0
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_queries_required(cls, values: dict) -> dict:
+        if values.get("sequential_responses") and not values.get("queries"):
+            raise ValueError(
+                "queries is required when sequential_responses is set to True"
+            )
+        return values
+
+    def get_num_tokens(self, text: str) -> int:
+        """Return number of tokens."""
+        return len(text.split())
+
+    @property
+    def _llm_type(self) -> str:
+        """Return type of llm."""
+        return "fake"
+
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        if self.sequential_responses:
+            return self._get_next_response_in_sequence
+        if self.queries is not None:
+            return self.queries[prompt]
+        if stop is None:
+            return "foo"
+        else:
+            return "bar"
+
+    @property
+    def _identifying_params(self) -> Dict[str, Any]:
+        return {}
+
+    @property
+    def _get_next_response_in_sequence(self) -> str:
+        queries = cast(Mapping, self.queries)
+        response = queries[list(queries.keys())[self.response_index]]
+        self.response_index = self.response_index + 1
+        return response
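A minimal usage sketch for the new cache adapter, assuming a CrateDB server is reachable at the illustrative URL `crate://localhost:4200`; engine setup and table layout follow the defaults exercised by the tests above.

    import sqlalchemy as sa
    from langchain.globals import set_llm_cache
    from langchain_cratedb import CrateDBCache

    # Illustrative connection URL; point it at your own CrateDB instance.
    engine = sa.create_engine("crate://localhost:4200")

    # Register the CrateDB-backed cache globally. Repeated LLM calls with the
    # same prompt and parameters are then answered from the default
    # `FullLLMCache` table instead of hitting the model again.
    set_llm_cache(CrateDBCache(engine=engine))

A custom table can be used by passing a `cache_schema` class, as exercised by `test_sqlalchemy_cache` above.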