Cache: Add CrateDBCache based on SQLAlchemyCache
amotl committed Dec 23, 2024
1 parent 5bbb5bb commit c41a9be
Showing 8 changed files with 431 additions and 0 deletions.
2 changes: 2 additions & 0 deletions CHANGES.md
@@ -2,6 +2,8 @@


## Unreleased
- Added implementation and software tests for `CrateDBCache`,
  deriving from `SQLAlchemyCache`.

## v0.0.0 - 2024-12-16
- Make it work
2 changes: 2 additions & 0 deletions langchain_cratedb/__init__.py
@@ -5,6 +5,7 @@

patch_sqlalchemy_dialect()

from langchain_cratedb.cache import CrateDBCache
from langchain_cratedb.chat_history import CrateDBChatMessageHistory
from langchain_cratedb.loaders import CrateDBLoader
from langchain_cratedb.vectorstores import (
@@ -20,6 +21,7 @@
del metadata # optional, avoids polluting the results of dir(__package__)

__all__ = [
    "CrateDBCache",
    "CrateDBChatMessageHistory",
    "CrateDBLoader",
    "CrateDBVectorStore",
18 changes: 18 additions & 0 deletions langchain_cratedb/cache.py
@@ -0,0 +1,18 @@
import typing as t

import sqlalchemy as sa
from langchain_community.cache import FullLLMCache, SQLAlchemyCache
from sqlalchemy_cratedb.support import refresh_after_dml


class CrateDBCache(SQLAlchemyCache):
    """
    CrateDB adapter for LangChain's standard (full) LLM cache subsystem.
    It builds on SQLAlchemyCache 1:1.
    """

    def __init__(
        self, engine: sa.Engine, cache_schema: t.Type[FullLLMCache] = FullLLMCache
    ):
        # CrateDB needs an explicit refresh before written data becomes
        # visible to subsequent reads; `refresh_after_dml` installs a
        # hook on the engine that takes care of it.
        refresh_after_dml(engine)
        super().__init__(engine, cache_schema)
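
For context, a minimal usage sketch (not part of this commit; the connection URL and the running CrateDB instance are assumptions):

import sqlalchemy as sa
from langchain.globals import set_llm_cache
from langchain_cratedb import CrateDBCache

# Hypothetical connection string; adjust host and credentials as needed.
engine = sa.create_engine("crate://localhost:4200")

# Register the cache globally. Subsequent LLM invocations with an
# identical prompt and parameter set are answered from the cache table.
set_llm_cache(CrateDBCache(engine=engine))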
18 changes: 18 additions & 0 deletions langchain_cratedb/patches.py
@@ -1,3 +1,20 @@
import warnings

from sqlalchemy_cratedb.compiler import CrateDDLCompiler


def ddl_compiler_visit_create_index(self, create, **kw) -> str:  # type: ignore[no-untyped-def]
    """
    CrateDB does not support `CREATE INDEX` statements.
    """
    warnings.warn(
        "CrateDB does not support `CREATE INDEX` statements; "
        "they will be omitted when generating DDL statements.",
        stacklevel=2,
    )
    # Emit a no-op statement so that DDL execution does not fail.
    return "SELECT 1"


def patch_sqlalchemy_dialect() -> None:
    """
    Fixes `AttributeError: 'CrateCompilerSA20' object has no attribute 'visit_on_conflict_do_update'`
@@ -10,6 +27,7 @@ def patch_sqlalchemy_dialect() -> None:

    CrateCompiler.visit_on_conflict_do_update = PGCompiler.visit_on_conflict_do_update
    CrateCompiler._on_conflict_target = PGCompiler._on_conflict_target
    CrateDDLCompiler.visit_create_index = ddl_compiler_visit_create_index


patch_sqlalchemy_dialect()
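
To illustrate the patch, a small sketch (assumptions: importing `langchain_cratedb` applies the patch on import, and no database connection is needed because `create_engine` is lazy):

import sqlalchemy as sa
from sqlalchemy.schema import CreateIndex

import langchain_cratedb  # noqa: F401  # importing applies patch_sqlalchemy_dialect()

engine = sa.create_engine("crate://")  # lazy; no connection is made
metadata = sa.MetaData()
table = sa.Table("demo", metadata, sa.Column("name", sa.String))
index = sa.Index("ix_demo_name", table.c.name)

# With the patch in place, compiling the DDL emits the warning and
# renders the no-op "SELECT 1" instead of a CREATE INDEX statement.
print(CreateIndex(index).compile(dialect=engine.dialect))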
111 changes: 111 additions & 0 deletions tests/integration_tests/cache/test_standard_mcsa_cache.py
@@ -0,0 +1,111 @@
"""
Test standard and semantic caching.
Derived from Memcached and SQLAlchemy.
Source: https://github.com/langchain-ai/langchain/blob/langchain-core%3D%3D0.3.28/libs/community/tests/integration_tests/cache/test_memcached_cache.py
"""

import pytest
import sqlalchemy as sa
from langchain.globals import get_llm_cache, set_llm_cache
from langchain_core.caches import BaseCache
from langchain_core.outputs import Generation, LLMResult

from langchain_cratedb import CrateDBCache
from tests.utils import FakeLLM


@pytest.fixture()
def cache(engine: sa.Engine) -> BaseCache:
    return CrateDBCache(engine=engine)


def test_memcached_cache(cache: BaseCache) -> None:
    """Test general caching"""

    set_llm_cache(cache)
    llm = FakeLLM()

    params = llm.dict()
    params["stop"] = None
    llm_string = str(sorted([(k, v) for k, v in params.items()]))
    get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
    output = llm.generate(["foo"])
    expected_output = LLMResult(
        generations=[[Generation(text="fizz")]],
        llm_output={},
    )
    assert output == expected_output
    # Clear the cache.
    get_llm_cache().clear()


def test_memcached_cache_flush(cache: BaseCache) -> None:
    """Test flushing cache"""

    set_llm_cache(cache)
    llm = FakeLLM()

    params = llm.dict()
    params["stop"] = None
    llm_string = str(sorted([(k, v) for k, v in params.items()]))
    get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
    output = llm.generate(["foo"])
    expected_output = LLMResult(
        generations=[[Generation(text="fizz")]],
        llm_output={},
    )
    assert output == expected_output
    # Clear the cache. `delay` and `noreply` are leftovers from the
    # Memcached original; SQLAlchemyCache accepts and ignores them
    # through its `**kwargs` signature.
    get_llm_cache().clear(delay=0, noreply=False)

    # After the cache has been cleared, the result shouldn't be the same.
    output = llm.generate(["foo"])
    assert output != expected_output


def test_sqlalchemy_cache(engine: sa.Engine) -> None:
    """Test caching behavior with a custom cache schema."""

    from sqlalchemy_cratedb.support import patch_autoincrement_timestamp

    patch_autoincrement_timestamp()

    Base = sa.orm.declarative_base()

    class FulltextLLMCache(Base):  # type: ignore
        """CrateDB table for fulltext-indexed LLM Cache."""

        __tablename__ = "llm_cache_fulltext"
        # TODO: Original. Can it be converged by adding a polyfill to
        #       `sqlalchemy-cratedb`?
        # id = Column(Integer, Sequence("cache_id"), primary_key=True)
        id = sa.Column(sa.BigInteger, server_default=sa.func.now(), primary_key=True)
        prompt = sa.Column(sa.String, nullable=False)
        llm = sa.Column(sa.String, nullable=False)
        idx = sa.Column(sa.Integer)
        response = sa.Column(sa.String)

    set_llm_cache(CrateDBCache(engine, FulltextLLMCache))
    get_llm_cache().clear()

    llm = FakeLLM()
    params = llm.dict()
    params["stop"] = None
    llm_string = str(sorted([(k, v) for k, v in params.items()]))
    get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
    output = llm.generate(["foo", "bar", "foo"])
    expected_cache_output = [Generation(text="foo")]
    cache_output = get_llm_cache().lookup("bar", llm_string)
    assert cache_output == expected_cache_output
    set_llm_cache(None)
    expected_generations = [
        [Generation(text="fizz")],
        [Generation(text="foo")],
        [Generation(text="fizz")],
    ]
    expected_output = LLMResult(
        generations=expected_generations,
        llm_output=None,
    )
    assert output == expected_output
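
These tests rely on an `engine` fixture that is not part of this changeset. A plausible conftest sketch (the fixture body and connection URL are assumptions):

import pytest
import sqlalchemy as sa


@pytest.fixture()
def engine() -> sa.Engine:
    # Hypothetical: assumes a CrateDB instance on the default HTTP port.
    return sa.create_engine("crate://localhost:4200")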