diff --git a/CHANGES.md b/CHANGES.md
index 523c2c5..bec770b 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -3,7 +3,8 @@
 ## Unreleased
 
 - Added implementation and software tests for `CrateDBCache`,
-  deriving from `SQLAlchemyCache`.
+  deriving from `SQLAlchemyCache`, and `CrateDBSemanticCache`,
+  building upon `CrateDBVectorStore`.
 
 ## v0.0.0 - 2024-12-16
 - Make it work
diff --git a/examples/cache.py b/examples/cache.py
new file mode 100644
index 0000000..b45d27e
--- /dev/null
+++ b/examples/cache.py
@@ -0,0 +1,129 @@
+# ruff: noqa: T201
+"""
+Use CrateDB to cache LLM prompts and responses.
+
+The standard / full cache avoids invoking the LLM when the supplied
+prompt is exactly the same as one that has been encountered already.
+
+The semantic cache allows users to retrieve cached prompts based on semantic
+similarity between the user input and previously cached inputs.
+
+With the cache enabled, repeated LLM conversations do not need to talk
+to the LLM (API) again, so they can also work offline.
+"""
+
+import sqlalchemy as sa
+from langchain.globals import set_llm_cache
+from langchain_openai import ChatOpenAI, OpenAIEmbeddings
+
+from langchain_cratedb import CrateDBCache, CrateDBSemanticCache
+
+
+def standard_cache():
+    """
+    Demonstrate LangChain standard cache with CrateDB.
+    """
+
+    # Configure cache.
+    engine = sa.create_engine("crate://crate@localhost:4200/?schema=testdrive")
+    set_llm_cache(CrateDBCache(engine))
+
+    # Invoke LLM conversation.
+    llm = ChatOpenAI(
+        # model_name="gpt-3.5-turbo",
+        # model_name="gpt-4o-mini",
+        model_name="chatgpt-4o-latest",
+        temperature=0.7,
+    )
+    print()
+    print("Asking with standard cache:")
+    answer = llm.invoke("What is the answer to everything?")
+    print(answer.content)
+
+    # Turn off cache.
+    set_llm_cache(None)
+
+
+def semantic_cache():
+    """
+    Demonstrate LangChain semantic cache with CrateDB.
+    """
+
+    # Configure LLM models.
+    # model_name_embedding = "text-embedding-ada-002"
+    model_name_embedding = "text-embedding-3-small"
+    # model_name_embedding = "text-embedding-3-large"
+
+    # model_name_chat = "gpt-3.5-turbo"
+    # model_name_chat = "gpt-4o-mini"
+    model_name_chat = "chatgpt-4o-latest"
+
+    # Configure embeddings.
+    embeddings = OpenAIEmbeddings(model=model_name_embedding)
+
+    # Configure cache.
+    engine = sa.create_engine("crate://crate@localhost:4200/?schema=testdrive")
+    set_llm_cache(
+        CrateDBSemanticCache(
+            embedding=embeddings,
+            connection=engine,
+            search_threshold=1.0,
+        )
+    )
+
+    # Invoke LLM conversation.
+    llm = ChatOpenAI(
+        model_name=model_name_chat,
+    )
+    print()
+    print("Asking with semantic cache:")
+    answer = llm.invoke("What is the answer to everything?")
+    print(answer.content)
+
+    # Turn off cache.
+    set_llm_cache(None)
+
+
+if __name__ == "__main__":
+    standard_cache()
+    semantic_cache()
+
+
+"""
+What is the answer to everything?
+
+Date: 2024-12-23
+
+## gpt-3.5-turbo
+The answer to everything is subjective and may vary depending on individual
+beliefs or philosophies. Some may say that love is the answer to everything,
+while others may say that knowledge or self-awareness is the key. Ultimately,
+the answer to everything may be different for each person and can only be
+discovered through personal reflection and introspection.
+
+## gpt-4o-mini
+The answer to the ultimate question of life, the universe, and everything,
+according to Douglas Adams' "The Hitchhiker's Guide to the Galaxy", is
+famously given as the number 42. However, the context and meaning behind
+that answer remains a philosophical and humorous mystery. In a broader
+sense, different people and cultures may have various interpretations of
+what the "answer to everything" truly is, often reflecting their beliefs,
+values, and experiences.
+
+## chatgpt-4o-latest, pure
+Ah, you're referencing the famous answer from Douglas Adams'
+*The Hitchhiker's Guide to the Galaxy*! In the book, the supercomputer
+Deep Thought determines that the "Answer to the Ultimate Question of
+Life, the Universe, and Everything" is **42**.
+Of course, the real kicker is that no one actually knows what the Ultimate
+Question is. So, while 42 is the answer, its true meaning remains a cosmic
+mystery! 😊
+
+## chatgpt-4o-latest, with text-embedding-3-small embeddings
+Ah, you're referring to the famous answer from Douglas Adams'
+*The Hitchhiker's Guide to the Galaxy*! The answer to the ultimate question
+of life, the universe, and everything is **42**. However, as the story
+humorously points out, the actual *question* remains unknown. 😊
+If you're looking for a deeper or more philosophical answer, feel free to
+elaborate!
+"""
diff --git a/langchain_cratedb/__init__.py b/langchain_cratedb/__init__.py
index d6511a9..7f6322d 100644
--- a/langchain_cratedb/__init__.py
+++ b/langchain_cratedb/__init__.py
@@ -5,7 +5,7 @@
 
 patch_sqlalchemy_dialect()
 
-from langchain_cratedb.cache import CrateDBCache
+from langchain_cratedb.cache import CrateDBCache, CrateDBSemanticCache
 from langchain_cratedb.chat_history import CrateDBChatMessageHistory
 from langchain_cratedb.loaders import CrateDBLoader
 from langchain_cratedb.vectorstores import (
@@ -24,6 +24,7 @@
     "CrateDBCache",
     "CrateDBChatMessageHistory",
     "CrateDBLoader",
+    "CrateDBSemanticCache",
     "CrateDBVectorStore",
     "CrateDBVectorStoreMultiCollection",
     "__version__",
diff --git a/langchain_cratedb/cache.py b/langchain_cratedb/cache.py
index d9866ca..ccd8c7d 100644
--- a/langchain_cratedb/cache.py
+++ b/langchain_cratedb/cache.py
@@ -1,9 +1,16 @@
 import typing as t
 
 import sqlalchemy as sa
-from langchain_community.cache import FullLLMCache, SQLAlchemyCache
+from langchain_community.cache import FullLLMCache, SQLAlchemyCache, _hash
+from langchain_core.caches import RETURN_VAL_TYPE, BaseCache
+from langchain_core.embeddings import Embeddings
+from langchain_core.load import dumps, loads
+from langchain_core.outputs import Generation
 from sqlalchemy_cratedb.support import refresh_after_dml
 
+from langchain_cratedb.vectorstores import CrateDBVectorStore
+from langchain_cratedb.vectorstores.main import DBConnection
+
 
 class CrateDBCache(SQLAlchemyCache):
     """
@@ -16,3 +23,146 @@ def __init__(
     ):
         refresh_after_dml(engine)
         super().__init__(engine, cache_schema)
+
+
+class CrateDBSemanticCache(BaseCache):
+    """
+    CrateDB adapter for the LangChain semantic cache subsystem.
+    It uses CrateDBVectorStore as a backend.
+    """
+
+    def __init__(
+        self,
+        embedding: Embeddings,
+        *,
+        connection: t.Union[
+            None, DBConnection, sa.Engine, sa.ext.asyncio.AsyncEngine, str
+        ] = None,
+        cache_table_prefix: str = "cache_",
+        search_threshold: float = 0.2,
+        **kwargs: t.Any,
+    ):
+        """Initialize with necessary components.
+
+        Args:
+            embedding (Embeddings): A text embedding model.
+            connection (optional): SQLAlchemy engine, async engine, or connection
+                string to the CrateDB database. Defaults to None.
+            cache_table_prefix (str, optional): Prefix for the cache table name.
+                Defaults to "cache_".
+            search_threshold (float, optional): The maximum distance score for
+                a search result to be considered a match. Defaults to 0.2.
+
+        Examples:
+            Basic Usage:
+
+            .. code-block:: python
+
+                import langchain
+                from langchain_cratedb import CrateDBSemanticCache
+                from langchain.embeddings import OpenAIEmbeddings
+
+                langchain.llm_cache = CrateDBSemanticCache(
+                    embedding=OpenAIEmbeddings(),
+                    connection="crate://user:password@127.0.0.1:4200/?schema=testdrive",
+                )
+
+            Advanced Usage:
+
+            .. code-block:: python
+
+                import sqlalchemy as sa
+
+                import langchain
+                from langchain_cratedb import CrateDBSemanticCache
+                from langchain.embeddings import OpenAIEmbeddings
+
+                engine = sa.create_engine(
+                    "crate://user:password@127.0.0.1:4200/?schema=testdrive"
+                )
+                langchain.llm_cache = CrateDBSemanticCache(
+                    embedding=OpenAIEmbeddings(),
+                    connection=engine,
+                    cache_table_prefix="cache_",
+                    search_threshold=0.2,
+                )
+        """
+
+        self._cache_dict: t.Dict[str, CrateDBVectorStore] = {}
+        self.embedding = embedding
+        self.connection = connection
+        self.cache_table_prefix = cache_table_prefix
+        self.search_threshold = search_threshold
+
+        # Pass the rest of the kwargs to the CrateDBVectorStore constructor.
+        self.connection_kwargs = kwargs
+
+    def _index_name(self, llm_string: str) -> str:
+        hashed_index = _hash(llm_string)
+        return f"{self.cache_table_prefix}{hashed_index}"
+
+    def _get_llm_cache(self, llm_string: str) -> CrateDBVectorStore:
+        index_name = self._index_name(llm_string)
+
+        # Return the vector store client for the specific LLM string.
+        if index_name not in self._cache_dict:
+            vs = self._cache_dict[index_name] = CrateDBVectorStore(
+                embeddings=self.embedding,
+                connection=self.connection,
+                collection_name=index_name,
+                **self.connection_kwargs,
+            )
+            _embedding = self.embedding.embed_query(text="test")
+            vs._init_models(_embedding)
+            vs.create_tables_if_not_exists()
+        llm_cache = self._cache_dict[index_name]
+        llm_cache.create_collection()
+        return llm_cache
+
+    def lookup(self, prompt: str, llm_string: str) -> t.Optional[RETURN_VAL_TYPE]:
+        """Look up based on prompt and llm_string."""
+        llm_cache = self._get_llm_cache(llm_string)
+        generations: t.List = []
+        # Read from the vector store.
+        results = llm_cache.similarity_search_with_score(
+            query=prompt,
+            k=1,
+        )
+        """
+        from langchain_postgres.vectorstores import DistanceStrategy
+        if llm_cache.distance_strategy != DistanceStrategy.EUCLIDEAN:
+            raise NotImplementedError(f"CrateDB's vector store only implements Euclidean distance. "
+                                      f"Your selection was: {llm_cache.distance_strategy}")
+        """  # noqa: E501
+        if results:
+            for document_score in results:
+                if document_score[1] <= self.search_threshold:
+                    generations.extend(loads(document_score[0].metadata["return_val"]))
+        return generations if generations else None
+
+    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
+        """Update cache based on prompt and llm_string."""
+        for gen in return_val:
+            if not isinstance(gen, Generation):
+                raise ValueError(
+                    "CrateDBSemanticCache only supports caching of "
+                    f"normal LLM generations, got {type(gen)}"
+                )
+        llm_cache = self._get_llm_cache(llm_string)
+        metadata = {
+            "llm_string": llm_string,
+            "prompt": prompt,
+            "return_val": dumps([g for g in return_val]),
+        }
+        llm_cache.add_texts(texts=[prompt], metadatas=[metadata])
+
+    def clear(self, **kwargs: t.Any) -> None:
+        """Clear semantic cache for a given llm_string."""
+        if "llm_string" in kwargs:
+            index_name = self._index_name(kwargs["llm_string"])
+            if index_name in self._cache_dict:
+                vs = self._cache_dict[index_name]
+                with vs._make_sync_session() as session:
+                    collection = vs.get_collection(session)
+                    collection.embeddings.clear()
+                    session.commit()
+                del self._cache_dict[index_name]
+        else:
+            raise NotImplementedError(
+                "Clearing cache elements without constraints is not implemented yet"
+            )
diff --git a/tests/integration_tests/cache/fake_embeddings.py b/tests/integration_tests/cache/fake_embeddings.py
new file mode 100644
index 0000000..02177ce
--- /dev/null
+++ b/tests/integration_tests/cache/fake_embeddings.py
@@ -0,0 +1,86 @@
+"""
+Fake Embedding class for testing purposes.
+
+Source: https://github.com/langchain-ai/langchain/blob/langchain-core%3D%3D0.3.28/libs/community/tests/integration_tests/cache/fake_embeddings.py
+"""
+
+import math
+from typing import List
+
+from langchain_core.embeddings import Embeddings
+
+fake_texts = ["foo", "bar", "baz"]
+
+
+class FakeEmbeddings(Embeddings):
+    """Fake embeddings functionality for testing."""
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        """Return simple embeddings.
+        Embeddings encode each text as its index."""
+        return [[float(1.0)] * 9 + [float(i)] for i in range(len(texts))]
+
+    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
+        return self.embed_documents(texts)
+
+    def embed_query(self, text: str) -> List[float]:
+        """Return constant query embeddings.
+        Embeddings are identical to embed_documents(texts)[0].
+        Distance to each text will be that text's index,
+        as it was passed to embed_documents."""
+        return [float(1.0)] * 9 + [float(0.0)]
+
+    async def aembed_query(self, text: str) -> List[float]:
+        return self.embed_query(text)
+
+
+class ConsistentFakeEmbeddings(FakeEmbeddings):
+    """Fake embeddings which remember all the texts seen so far to return consistent
+    vectors for the same texts."""
+
+    def __init__(self, dimensionality: int = 10) -> None:
+        self.known_texts: List[str] = []
+        self.dimensionality = dimensionality
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        """Return consistent embeddings for each text seen so far."""
+        out_vectors = []
+        for text in texts:
+            if text not in self.known_texts:
+                self.known_texts.append(text)
+            vector = [float(1.0)] * (self.dimensionality - 1) + [
+                float(self.known_texts.index(text))
+            ]
+            out_vectors.append(vector)
+        return out_vectors
+
+    def embed_query(self, text: str) -> List[float]:
+        """Return consistent embeddings for the text, if seen before, or a constant
+        one if the text is unknown."""
+        return self.embed_documents([text])[0]
+
+
+class AngularTwoDimensionalEmbeddings(Embeddings):
+    """
+    From angles (as strings in units of pi) to unit embedding vectors on a circle.
+    """
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        """
+        Make a list of texts into a list of embedding vectors.
+        """
+        return [self.embed_query(text) for text in texts]
+
+    def embed_query(self, text: str) -> List[float]:
+        """
+        Convert input text to a 'vector' (list of floats).
+        If the text is a number, use it as the angle for the
+        unit vector in units of pi.
+        Any other input text becomes the singular result [0, 0] !
+        """
+        try:
+            angle = float(text)
+            return [math.cos(angle * math.pi), math.sin(angle * math.pi)]
+        except ValueError:
+            # Assume: just test string, no attention is paid to values.
+            return [0.0, 0.0]
diff --git a/tests/integration_tests/cache/test_semantic_cache.py b/tests/integration_tests/cache/test_semantic_cache.py
new file mode 100644
index 0000000..2923fa1
--- /dev/null
+++ b/tests/integration_tests/cache/test_semantic_cache.py
@@ -0,0 +1,185 @@
+"""
+Test semantic cache.
+Derived from SingleStoreDB.
+
+Source: https://github.com/langchain-ai/langchain/blob/langchain-core%3D%3D0.3.28/libs/community/tests/integration_tests/cache/test_singlestoredb_cache.py
+"""
+
+import typing as t
+import uuid
+
+import pytest
+import sqlalchemy as sa
+from langchain_core.embeddings import Embeddings
+from langchain_core.globals import get_llm_cache, set_llm_cache
+from langchain_core.language_models.fake_chat_models import FakeChatModel
+from langchain_core.load import dumps
+from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
+from langchain_core.outputs import ChatGeneration, Generation, LLMResult
+
+from langchain_cratedb import CrateDBSemanticCache
+from tests.integration_tests.cache.fake_embeddings import (
+    ConsistentFakeEmbeddings,
+    FakeEmbeddings,
+)
+from tests.utils import FakeLLM
+
+
+def random_string() -> str:
+    return str(uuid.uuid4())
+
+
+@pytest.fixture(autouse=True)
+def set_cache_and_teardown() -> t.Generator[None, None, None]:
+    yield
+    set_llm_cache(None)
+
+
+def test_semantic_cache_single(engine: sa.Engine) -> None:
+    """
+    Test semantic cache functionality with a single item.
+    Derived from OpenSearch/SingleStore.
+    """
+    set_llm_cache(
+        CrateDBSemanticCache(
+            embedding=FakeEmbeddings(),
+            connection=engine,
+            search_threshold=1.0,
+        )
+    )
+    llm = FakeLLM()
+    params = llm.dict()
+    params["stop"] = None
+    llm_string = str(sorted([(k, v) for k, v in params.items()]))
+    get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
+    cache_output = get_llm_cache().lookup("bar", llm_string)
+    assert cache_output == [Generation(text="fizz")]
+
+    get_llm_cache().clear(llm_string=llm_string)
+    output = get_llm_cache().lookup("bar", llm_string)
+    assert output != [Generation(text="fizz")]
+
+
+def test_semantic_cache_multi(engine: sa.Engine) -> None:
+    """
+    Test semantic cache functionality with multiple items.
+    Derived from OpenSearch/SingleStore.
+    """
+    set_llm_cache(
+        CrateDBSemanticCache(
+            embedding=FakeEmbeddings(),
+            connection=engine,
+            search_threshold=1.0,
+        )
+    )
+
+    llm = FakeLLM()
+    params = llm.dict()
+    params["stop"] = None
+    llm_string = str(sorted([(k, v) for k, v in params.items()]))
+    get_llm_cache().update(
+        "foo", llm_string, [Generation(text="fizz"), Generation(text="Buzz")]
+    )
+
+    # foo and bar will have the same embedding produced by FakeEmbeddings
+    cache_output = get_llm_cache().lookup("bar", llm_string)
+    assert cache_output == [Generation(text="fizz"), Generation(text="Buzz")]
+
+    # clear the cache
+    get_llm_cache().clear(llm_string=llm_string)
+    output = get_llm_cache().lookup("bar", llm_string)
+    assert output != [Generation(text="fizz"), Generation(text="Buzz")]
+
+
+def test_semantic_cache_chat(engine: sa.Engine) -> None:
+    """
+    Test semantic cache functionality for chat messages.
+    Derived from Redis.
+    """
+    set_llm_cache(
+        CrateDBSemanticCache(
+            embedding=FakeEmbeddings(),
+            connection=engine,
+            search_threshold=1.0,
+        )
+    )
+    llm = FakeChatModel()
+    params = llm.dict()
+    params["stop"] = None
+    llm_string = str(sorted([(k, v) for k, v in params.items()]))
+    prompt: t.List[BaseMessage] = [HumanMessage(content="foo")]
+    llm_cache = t.cast(CrateDBSemanticCache, get_llm_cache())
+    llm_cache.update(
+        dumps(prompt), llm_string, [ChatGeneration(message=AIMessage(content="fizz"))]
+    )
+    output = llm.generate([prompt])
+    expected_output = LLMResult(
+        generations=[[ChatGeneration(message=AIMessage(content="fizz"))]],
+        llm_output={},
+    )
+    assert output == expected_output
+    llm_cache.clear(llm_string=llm_string)
+
+
+@pytest.mark.parametrize("embedding", [ConsistentFakeEmbeddings()])
+@pytest.mark.parametrize(
+    "prompts, generations",
+    [
+        # Single prompt, single generation
+        ([random_string()], [[random_string()]]),
+        # Single prompt, multiple generations
+        ([random_string()], [[random_string(), random_string()]]),
+        # Single prompt, multiple generations
+        ([random_string()], [[random_string(), random_string(), random_string()]]),
+        # Multiple prompts, multiple generations
+        # (
+        #     [random_string(), random_string()],
+        #     [[random_string()], [random_string(), random_string()]],
+        # ),
+    ],
+    ids=[
+        "single_prompt_single_generation",
+        "single_prompt_multiple_generations",
+        "single_prompt_multiple_generations",
+        # "multiple_prompts_multiple_generations",
+    ],
+)
+def test_semantic_cache_hit(
+    embedding: Embeddings,
+    prompts: t.List[str],
+    generations: t.List[t.List[str]],
+    engine: sa.Engine,
+) -> None:
+    """
+    Test semantic cache functionality with hits.
+    Derived from Redis.
+    """
+    set_llm_cache(
+        CrateDBSemanticCache(
+            embedding=embedding,
+            connection=engine,
+            search_threshold=1.0,
+        )
+    )
+
+    llm = FakeLLM()
+    params = llm.dict()
+    params["stop"] = None
+    llm_string = str(sorted([(k, v) for k, v in params.items()]))
+
+    llm_generations = [
+        [
+            Generation(text=generation, generation_info=params)
+            for generation in prompt_i_generations
+        ]
+        for prompt_i_generations in generations
+    ]
+    llm_cache = t.cast(CrateDBSemanticCache, get_llm_cache())
+    for prompt_i, llm_generations_i in zip(prompts, llm_generations):
+        print(prompt_i)  # noqa: T201
+        print(llm_generations_i)  # noqa: T201
+        llm_cache.update(prompt_i, llm_string, llm_generations_i)
+    llm.generate(prompts)
+    assert llm.generate(prompts) == LLMResult(
+        generations=llm_generations, llm_output={}
+    )