-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
9 changed files
with
990 additions
and
991 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,50 +1,72 @@ | ||
import hashlib | ||
from contextlib import contextmanager | ||
from typing import Sequence | ||
from typing import Generator, Sequence, cast | ||
|
||
from sqlalchemy import Connection, Table | ||
from sqlmodel import create_engine, Session, SQLModel | ||
from sqlalchemy import Connection | ||
from sqlmodel import Session, SQLModel, create_engine | ||
|
||
from prediction_market_agent_tooling.config import APIKeys | ||
from prediction_market_agent_tooling.tools.caches.serializers import ( | ||
json_serializer, | ||
json_deserializer, | ||
json_serializer, | ||
) | ||
from prediction_market_agent_tooling.tools.singleton import SingletonMeta | ||
|
||
|
||
class DBManager(metaclass=SingletonMeta): | ||
def __init__(self, api_keys: APIKeys) -> None: | ||
# We pass in serializers as used by db_cache (no reason not to). | ||
class DBManager: | ||
_instances: dict[str, "DBManager"] = {} | ||
|
||
def __new__(cls, api_keys: APIKeys | None = None) -> "DBManager": | ||
sqlalchemy_db_url = (api_keys or APIKeys()).sqlalchemy_db_url | ||
secret_value = sqlalchemy_db_url.get_secret_value() | ||
url_hash = hashlib.md5(secret_value.encode()).hexdigest() | ||
if url_hash not in cls._instances: | ||
instance = super(DBManager, cls).__new__(cls) | ||
cls._instances[url_hash] = instance | ||
return cls._instances[url_hash] | ||
|
||
def __init__(self, api_keys: APIKeys | None = None) -> None: | ||
if hasattr(self, "_engine"): | ||
return | ||
sqlalchemy_db_url = (api_keys or APIKeys()).sqlalchemy_db_url | ||
self._engine = create_engine( | ||
api_keys.sqlalchemy_db_url.get_secret_value(), | ||
sqlalchemy_db_url.get_secret_value(), | ||
json_serializer=json_serializer, | ||
json_deserializer=json_deserializer, | ||
pool_size=20, | ||
pool_recycle=3600, | ||
echo=True | ||
echo=True, | ||
) | ||
self.cache_table_initialized = False | ||
self.cache_table_initialized: dict[str, bool] = {} | ||
|
||
@contextmanager | ||
def get_session(self) -> Session: | ||
session = Session(self._engine) | ||
try: | ||
def get_session(self) -> Generator[Session, None, None]: | ||
with Session(self._engine) as session: | ||
yield session | ||
session.commit() | ||
except Exception: | ||
session.rollback() # Rollback if there's an error | ||
raise # Propagate the exception | ||
finally: | ||
session.close() | ||
|
||
@contextmanager | ||
def get_connection(self) -> Connection: | ||
def get_connection(self) -> Generator[Connection, None, None]: | ||
with self.get_session() as session: | ||
yield session.connection() | ||
|
||
def init_cache_metadata(self, tables: Sequence[Table]|None=None) -> None: | ||
if not self.cache_table_initialized: | ||
with self.get_connection() as conn: | ||
SQLModel.metadata.create_all(conn, tables=tables) | ||
def create_tables( | ||
self, sqlmodel_tables: Sequence[type[SQLModel]] | None = None | ||
) -> None: | ||
tables_to_create = ( | ||
[ | ||
table | ||
for sqlmodel_table in sqlmodel_tables | ||
if not self.cache_table_initialized.get( | ||
( | ||
table := SQLModel.metadata.tables[ | ||
cast(str, sqlmodel_table.__tablename__) | ||
] | ||
).name | ||
) | ||
] | ||
if sqlmodel_tables is not None | ||
else None | ||
) | ||
SQLModel.metadata.create_all(self._engine, tables=tables_to_create) | ||
|
||
self.cache_table_initialized = True | ||
for table in tables_to_create or []: | ||
self.cache_table_initialized[table.name] = True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,15 +1,23 @@ | ||
from typing import Generator | ||
|
||
import psycopg | ||
import pytest | ||
from pydantic.types import SecretStr | ||
from pytest_postgresql.executor import PostgreSQLExecutor | ||
from pytest_postgresql.janitor import DatabaseJanitor | ||
|
||
from prediction_market_agent_tooling.config import APIKeys | ||
|
||
|
||
@pytest.fixture | ||
def keys_with_sqlalchemy_db_url( | ||
postgresql: psycopg.Connection, | ||
@pytest.fixture(scope="session") | ||
def session_keys_with_postgresql_proc_and_enabled_cache( | ||
postgresql_proc: PostgreSQLExecutor, | ||
) -> Generator[APIKeys, None, None]: | ||
sqlalchemy_db_url = f"postgresql+psycopg2://{postgresql.info.user}:@{postgresql.info.host}:{postgresql.info.port}/{postgresql.info.dbname}" | ||
yield APIKeys(SQLALCHEMY_DB_URL=SecretStr(sqlalchemy_db_url), ENABLE_CACHE=True) | ||
with DatabaseJanitor( | ||
user=postgresql_proc.user, | ||
host=postgresql_proc.host, | ||
port=postgresql_proc.port, | ||
dbname=postgresql_proc.dbname, | ||
version=postgresql_proc.version, | ||
): | ||
sqlalchemy_db_url = f"postgresql+psycopg2://{postgresql_proc.user}:@{postgresql_proc.host}:{postgresql_proc.port}/{postgresql_proc.dbname}" | ||
yield APIKeys(SQLALCHEMY_DB_URL=SecretStr(sqlalchemy_db_url), ENABLE_CACHE=True) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,30 +1,35 @@ | ||
import typing as t | ||
from unittest.mock import patch | ||
import tempfile | ||
|
||
import pytest | ||
from pydantic import SecretStr | ||
from sqlmodel import SQLModel | ||
|
||
from prediction_market_agent_tooling.config import APIKeys | ||
from prediction_market_agent_tooling.tools.db.db_manager import DBManager | ||
from prediction_market_agent_tooling.tools.db.db_manager import APIKeys, DBManager | ||
|
||
|
||
def test_DBManager_creates_only_one_instance() -> None: | ||
with tempfile.NamedTemporaryFile( | ||
suffix=".db" | ||
) as temp_db1, tempfile.NamedTemporaryFile(suffix=".db") as temp_db3: | ||
db1 = DBManager( | ||
APIKeys(SQLALCHEMY_DB_URL=SecretStr(f"sqlite:///{temp_db1.name}")) | ||
) | ||
db2 = DBManager( | ||
APIKeys(SQLALCHEMY_DB_URL=SecretStr(f"sqlite:///{temp_db1.name}")) | ||
) | ||
db3 = DBManager( | ||
APIKeys(SQLALCHEMY_DB_URL=SecretStr(f"sqlite:///{temp_db3.name}")) | ||
) | ||
are_same_instance = db1 is db2 | ||
are_not_same_instance = db1 is not db3 | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def random_api_keys_with_sqlalchemy_db_url() -> t.Generator[APIKeys,None,None]: | ||
with patch("prediction_market_agent_tooling.config.APIKeys.sqlalchemy_db_url", SecretStr("abc")): | ||
keys = APIKeys() | ||
yield keys | ||
|
||
def test_DBManager_creates_only_one_instance(random_api_keys_with_sqlalchemy_db_url: APIKeys) -> None: | ||
|
||
db1 = DBManager(random_api_keys_with_sqlalchemy_db_url) | ||
db2 = DBManager(random_api_keys_with_sqlalchemy_db_url) | ||
are_same_instance = db1 is db2 | ||
assert are_same_instance, "DBManager created more than one instance!" | ||
assert ( | ||
are_not_same_instance | ||
), "DBManager returned isntance with a different SQLALCHEMY_DB_URL!" | ||
|
||
|
||
def test_session_can_be_used_by_metadata() -> None: | ||
db = DBManager() | ||
session = db.get_session() | ||
# ToDo - add more tests | ||
SQLModel.metadata.create_all(session.bind) | ||
# TODO: Gabriel what was the goal in this test? | ||
# def test_session_can_be_used_by_metadata() -> None: | ||
# db = DBManager() | ||
# session = db.get_session() | ||
# # ToDo - add more tests | ||
# SQLModel.metadata.create_all(session.bind) |
Oops, something went wrong.