
Commit: Fixes
kongzii committed Nov 21, 2024
1 parent a4497a9 commit 80c805b
Showing 9 changed files with 990 additions and 991 deletions.
1,701 changes: 832 additions & 869 deletions poetry.lock

Large diffs are not rendered by default.

52 changes: 25 additions & 27 deletions prediction_market_agent_tooling/tools/caches/db_cache.py
@@ -98,9 +98,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
         if not api_keys.ENABLE_CACHE:
             return func(*args, **kwargs)

-        with DBManager(api_keys).get_connection() as conn:
-            # Create table if it doesn't exist
-            SQLModel.metadata.create_all(conn, tables=[SQLModel.metadata.tables[FunctionCache.__tablename__]],)
+        DBManager(api_keys).create_tables([FunctionCache])

         # Convert *args and **kwargs to a single dictionary, where we have names for arguments passed as args as well.
         signature = inspect.signature(func)
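
For readers unfamiliar with the pattern being factored out here: SQLModel.metadata.create_all can be restricted to a subset of tables, which is what both the removed inline call and the new DBManager.create_tables helper rely on. A minimal self-contained sketch, using a hypothetical FunctionCacheDemo stand-in for the repository's FunctionCache:

from sqlmodel import Field, SQLModel, create_engine

class FunctionCacheDemo(SQLModel, table=True):  # hypothetical stand-in for FunctionCache
    id: int | None = Field(default=None, primary_key=True)
    function_name: str
    result: str

engine = create_engine("sqlite://")  # throwaway in-memory database
# Create only this one table, as the removed inline call did:
SQLModel.metadata.create_all(
    engine, tables=[SQLModel.metadata.tables[FunctionCacheDemo.__tablename__]]
)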
@@ -141,10 +139,9 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
         # Determine if the function returns or contains Pydantic BaseModel(s)
         return_type = func.__annotations__.get("return", None)
-        is_pydantic_model = False
-
-        if return_type is not None and contains_pydantic_model(return_type):
-            is_pydantic_model = True
+        is_pydantic_model = return_type is not None and contains_pydantic_model(
+            return_type
+        )

         with DBManager(api_keys).get_session() as session:
             # Try to get cached result
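
contains_pydantic_model is the repository's own helper and its body is not part of this diff; a hypothetical sketch of what such a check can look like, recursing into generic annotations like list[MyModel] or dict[str, MyModel]:

import typing

from pydantic import BaseModel

def contains_pydantic_model_sketch(return_type: typing.Any) -> bool:
    # Direct hit: the annotation itself is a BaseModel subclass.
    if isinstance(return_type, type) and issubclass(return_type, BaseModel):
        return True
    # Recurse into the type arguments of generics, e.g. list[MyModel].
    return any(
        contains_pydantic_model_sketch(arg) for arg in typing.get_args(return_type)
    )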
@@ -161,26 +158,26 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
                 cutoff_time = utcnow() - max_age
                 statement = statement.where(FunctionCache.created_at >= cutoff_time)
             cached_result = session.exec(statement).first()
-        if cached_result:
-            logger.info(
-                # Keep the special [case-hit] identifier so we can easily track it in GCP.
-                f"[cache-hit] Cache hit for {full_function_name} with args {args_dict} and output {cached_result.result}"
-            )
-            if is_pydantic_model:
-                # If the output contains any Pydantic models, we need to initialise them.
-                try:
-                    return convert_cached_output_to_pydantic(
-                        return_type, cached_result.result
-                    )
-                except ValueError as e:
-                    # In case of backward-incompatible pydantic model, just treat it as cache miss, to not error out.
-                    logger.warning(
-                        f"Can not validate {cached_result=} into {return_type=} because {e=}, treating as cache miss."
-                    )
-                    cached_result = None
-            else:
-                return cached_result.result
+            # We indent here to keep the session open, so that cached_result doesn't go out of scope.
+            if cached_result:
+                logger.info(
+                    # Keep the special [case-hit] identifier so we can easily track it in GCP.
+                    f"[cache-hit] Cache hit for {full_function_name} with args {args_dict} and output {cached_result.result}"
+                )
+                if is_pydantic_model:
+                    # If the output contains any Pydantic models, we need to initialise them.
+                    try:
+                        return convert_cached_output_to_pydantic(
+                            return_type, cached_result.result
+                        )
+                    except ValueError as e:
+                        # In case of backward-incompatible pydantic model, just treat it as cache miss, to not error out.
+                        logger.warning(
+                            f"Can not validate {cached_result=} into {return_type=} because {e=}, treating as cache miss."
+                        )
+                        cached_result = None
+                else:
+                    return cached_result.result

         # On cache miss, compute the result
         computed_result = func(*args, **kwargs)
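
The re-indentation is more than cosmetic: once get_session commits and closes, SQLAlchemy may expire the loaded row, so touching cached_result.result outside the with block can raise DetachedInstanceError. An illustrative sketch of the failure mode the new indentation avoids (names assumed, not from the commit):

from sqlmodel import Session

def risky_read(engine, statement):  # illustrative only
    with Session(engine) as session:
        row = session.exec(statement).first()
        session.commit()  # expires loaded attributes by default
    return row.result  # may raise DetachedInstanceError: the session is closed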
@@ -202,6 +199,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
         with DBManager(api_keys).get_session() as session:
             logger.info(f"Saving {cache_entry} into database.")
             session.add(cache_entry)
+            session.commit()

     return computed_result
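
The explicit session.commit() makes the write durable at the call site; the reworked get_session (below) also commits on a clean exit, so this appears defensive rather than strictly required. A short usage sketch of the save path, with names taken from this diff:

with DBManager(api_keys).get_session() as session:
    session.add(cache_entry)
    session.commit()  # also happens automatically on clean exit of get_session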

74 changes: 48 additions & 26 deletions prediction_market_agent_tooling/tools/db/db_manager.py
@@ -1,50 +1,72 @@
+import hashlib
 from contextlib import contextmanager
-from typing import Sequence
+from typing import Generator, Sequence, cast

-from sqlalchemy import Connection, Table
-from sqlmodel import create_engine, Session, SQLModel
+from sqlalchemy import Connection
+from sqlmodel import Session, SQLModel, create_engine

 from prediction_market_agent_tooling.config import APIKeys
 from prediction_market_agent_tooling.tools.caches.serializers import (
-    json_serializer,
     json_deserializer,
+    json_serializer,
 )
 from prediction_market_agent_tooling.tools.singleton import SingletonMeta


-class DBManager(metaclass=SingletonMeta):
-    def __init__(self, api_keys: APIKeys) -> None:
-        # We pass in serializers as used by db_cache (no reason not to).
+class DBManager:
+    _instances: dict[str, "DBManager"] = {}
+
+    def __new__(cls, api_keys: APIKeys | None = None) -> "DBManager":
+        sqlalchemy_db_url = (api_keys or APIKeys()).sqlalchemy_db_url
+        secret_value = sqlalchemy_db_url.get_secret_value()
+        url_hash = hashlib.md5(secret_value.encode()).hexdigest()
+        if url_hash not in cls._instances:
+            instance = super(DBManager, cls).__new__(cls)
+            cls._instances[url_hash] = instance
+        return cls._instances[url_hash]
+
+    def __init__(self, api_keys: APIKeys | None = None) -> None:
+        if hasattr(self, "_engine"):
+            return
+        sqlalchemy_db_url = (api_keys or APIKeys()).sqlalchemy_db_url
         self._engine = create_engine(
-            api_keys.sqlalchemy_db_url.get_secret_value(),
+            sqlalchemy_db_url.get_secret_value(),
             json_serializer=json_serializer,
             json_deserializer=json_deserializer,
             pool_size=20,
             pool_recycle=3600,
-            echo=True
+            echo=True,
         )
-        self.cache_table_initialized = False
+        self.cache_table_initialized: dict[str, bool] = {}

     @contextmanager
-    def get_session(self) -> Session:
-        session = Session(self._engine)
-        try:
+    def get_session(self) -> Generator[Session, None, None]:
+        with Session(self._engine) as session:
             yield session
             session.commit()
-        except Exception:
-            session.rollback()  # Rollback if there's an error
-            raise  # Propagate the exception
-        finally:
-            session.close()

     @contextmanager
-    def get_connection(self) -> Connection:
+    def get_connection(self) -> Generator[Connection, None, None]:
         with self.get_session() as session:
             yield session.connection()

-    def init_cache_metadata(self, tables: Sequence[Table]|None=None) -> None:
-        if not self.cache_table_initialized:
-            with self.get_connection() as conn:
-                SQLModel.metadata.create_all(conn, tables=tables)
-            self.cache_table_initialized = True
+    def create_tables(
+        self, sqlmodel_tables: Sequence[type[SQLModel]] | None = None
+    ) -> None:
+        tables_to_create = (
+            [
+                table
+                for sqlmodel_table in sqlmodel_tables
+                if not self.cache_table_initialized.get(
+                    (
+                        table := SQLModel.metadata.tables[
+                            cast(str, sqlmodel_table.__tablename__)
+                        ]
+                    ).name
+                )
+            ]
+            if sqlmodel_tables is not None
+            else None
+        )
+        SQLModel.metadata.create_all(self._engine, tables=tables_to_create)
+
+        for table in tables_to_create or []:
+            self.cache_table_initialized[table.name] = True
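
Two behavioural points worth noting in the rewritten DBManager: instances are now cached per database URL (keyed by an MD5 hash of the connection string) rather than via SingletonMeta, and create_tables skips DDL for tables it has already created on that engine. A usage sketch with illustrative sqlite URLs and a hypothetical DemoTable (import paths as used elsewhere in this commit):

from pydantic import SecretStr
from sqlmodel import Field, SQLModel

from prediction_market_agent_tooling.config import APIKeys
from prediction_market_agent_tooling.tools.db.db_manager import DBManager

class DemoTable(SQLModel, table=True):  # hypothetical table for the sketch
    id: int | None = Field(default=None, primary_key=True)

keys_a = APIKeys(SQLALCHEMY_DB_URL=SecretStr("sqlite:///a.db"))
keys_b = APIKeys(SQLALCHEMY_DB_URL=SecretStr("sqlite:///b.db"))

assert DBManager(keys_a) is DBManager(keys_a)  # same URL -> same cached instance
assert DBManager(keys_a) is not DBManager(keys_b)  # different URL -> new instance

DBManager(keys_a).create_tables([DemoTable])  # issues CREATE TABLE once
DBManager(keys_a).create_tables([DemoTable])  # no-op: recorded in cache_table_initialized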
@@ -24,19 +24,15 @@ class RelevantNewsCacheModel(SQLModel, table=True):


 class RelevantNewsResponseCache:
-    def __init__(self, api_keys: APIKeys|None = None):
-        self.db_manager = DBManager(api_keys if api_keys else APIKeys())
+    def __init__(self, api_keys: APIKeys | None = None):
+        self.db_manager = DBManager(api_keys)
         self._initialize_db()

     def _initialize_db(self) -> None:
         """
         Creates the tables if they don't exist
         """
-        with self.db_manager.get_connection() as conn:
-            SQLModel.metadata.create_all(
-                conn,
-                tables=[SQLModel.metadata.tables[RelevantNewsCacheModel.__tablename__]],
-            )
+        self.db_manager.create_tables([RelevantNewsCacheModel])

     def find(
         self,
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -52,7 +52,6 @@ python-dateutil = "^2.9.0.post0"
 types-python-dateutil = "^2.9.0.20240906"
 pinatapy-vourhey = "^0.2.0"
 hishel = "^0.0.31"
-pytest-postgresql = "^6.1.1"

 [tool.poetry.extras]
 openai = ["openai"]
@@ -65,6 +64,7 @@ mypy = "^1.11.1"
 black = "^23.12.1"
 ape-foundry = "^0.8.2"
 eth-ape = "^0.8.10,<0.8.17" # 0.8.17 doesn't work with the current configuration and needs a fix, see https://github.com/gnosis/prediction-market-agent-tooling/issues/518.
+pytest-postgresql = "^6.1.1"

 [build-system]
 requires = ["poetry-core"]
20 changes: 14 additions & 6 deletions tests/conftest.py
@@ -1,15 +1,23 @@
 from typing import Generator

-import psycopg
 import pytest
 from pydantic.types import SecretStr
+from pytest_postgresql.executor import PostgreSQLExecutor
+from pytest_postgresql.janitor import DatabaseJanitor

 from prediction_market_agent_tooling.config import APIKeys


-@pytest.fixture
-def keys_with_sqlalchemy_db_url(
-    postgresql: psycopg.Connection,
+@pytest.fixture(scope="session")
+def session_keys_with_postgresql_proc_and_enabled_cache(
+    postgresql_proc: PostgreSQLExecutor,
 ) -> Generator[APIKeys, None, None]:
-    sqlalchemy_db_url = f"postgresql+psycopg2://{postgresql.info.user}:@{postgresql.info.host}:{postgresql.info.port}/{postgresql.info.dbname}"
-    yield APIKeys(SQLALCHEMY_DB_URL=SecretStr(sqlalchemy_db_url), ENABLE_CACHE=True)
+    with DatabaseJanitor(
+        user=postgresql_proc.user,
+        host=postgresql_proc.host,
+        port=postgresql_proc.port,
+        dbname=postgresql_proc.dbname,
+        version=postgresql_proc.version,
+    ):
+        sqlalchemy_db_url = f"postgresql+psycopg2://{postgresql_proc.user}:@{postgresql_proc.host}:{postgresql_proc.port}/{postgresql_proc.dbname}"
+        yield APIKeys(SQLALCHEMY_DB_URL=SecretStr(sqlalchemy_db_url), ENABLE_CACHE=True)
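
A hypothetical test consuming the new session-scoped fixture (the test itself is not part of this commit): pytest-postgresql's postgresql_proc starts one Postgres process per session, and DatabaseJanitor creates and drops the actual database around the yielded APIKeys.

from prediction_market_agent_tooling.config import APIKeys
from prediction_market_agent_tooling.tools.db.db_manager import DBManager

def test_engine_connects(
    session_keys_with_postgresql_proc_and_enabled_cache: APIKeys,
) -> None:
    keys = session_keys_with_postgresql_proc_and_enabled_cache
    with DBManager(keys).get_session() as session:
        pass  # commits (a no-op here) and closes cleanly on exit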
51 changes: 28 additions & 23 deletions tests/tools/db/test_db_manager.py
@@ -1,30 +1,35 @@
-import typing as t
-from unittest.mock import patch
+import tempfile

-import pytest
 from pydantic import SecretStr
-from sqlmodel import SQLModel

-from prediction_market_agent_tooling.config import APIKeys
-from prediction_market_agent_tooling.tools.db.db_manager import DBManager
+from prediction_market_agent_tooling.tools.db.db_manager import APIKeys, DBManager


-@pytest.fixture(scope="module")
-def random_api_keys_with_sqlalchemy_db_url() -> t.Generator[APIKeys,None,None]:
-    with patch("prediction_market_agent_tooling.config.APIKeys.sqlalchemy_db_url", SecretStr("abc")):
-        keys = APIKeys()
-        yield keys
-
-def test_DBManager_creates_only_one_instance(random_api_keys_with_sqlalchemy_db_url: APIKeys) -> None:
-
-    db1 = DBManager(random_api_keys_with_sqlalchemy_db_url)
-    db2 = DBManager(random_api_keys_with_sqlalchemy_db_url)
-    are_same_instance = db1 is db2
+def test_DBManager_creates_only_one_instance() -> None:
+    with tempfile.NamedTemporaryFile(
+        suffix=".db"
+    ) as temp_db1, tempfile.NamedTemporaryFile(suffix=".db") as temp_db3:
+        db1 = DBManager(
+            APIKeys(SQLALCHEMY_DB_URL=SecretStr(f"sqlite:///{temp_db1.name}"))
+        )
+        db2 = DBManager(
+            APIKeys(SQLALCHEMY_DB_URL=SecretStr(f"sqlite:///{temp_db1.name}"))
+        )
+        db3 = DBManager(
+            APIKeys(SQLALCHEMY_DB_URL=SecretStr(f"sqlite:///{temp_db3.name}"))
+        )
+        are_same_instance = db1 is db2
+        are_not_same_instance = db1 is not db3
+
     assert are_same_instance, "DBManager created more than one instance!"
+    assert (
+        are_not_same_instance
+    ), "DBManager returned isntance with a different SQLALCHEMY_DB_URL!"


-def test_session_can_be_used_by_metadata() -> None:
-    db = DBManager()
-    session = db.get_session()
-    # ToDo - add more tests
-    SQLModel.metadata.create_all(session.bind)
+# TODO: Gabriel what was the goal in this test?
+# def test_session_can_be_used_by_metadata() -> None:
+#     db = DBManager()
+#     session = db.get_session()
+#     # ToDo - add more tests
+#     SQLModel.metadata.create_all(session.bind)
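
If the commented-out test is ever revived, note that get_session is now a @contextmanager generator rather than a method returning a bare Session, so the old `session = db.get_session()` call pattern no longer works. A sketch under the new API (illustrative, not part of the commit):

from pydantic import SecretStr
from sqlmodel import SQLModel

from prediction_market_agent_tooling.tools.db.db_manager import APIKeys, DBManager

def test_metadata_can_use_session() -> None:
    db = DBManager(APIKeys(SQLALCHEMY_DB_URL=SecretStr("sqlite://")))
    with db.get_session() as session:
        # The session is bound to DBManager's engine, so DDL can run through it.
        SQLModel.metadata.create_all(session.get_bind())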