First attempt at DB_Manager #555

Open · wants to merge 7 commits into base: main

Conversation

gabrielfior
Contributor

No description provided.

@gabrielfior gabrielfior linked an issue Nov 13, 2024 that may be closed by this pull request

coderabbitai bot commented Nov 13, 2024

Walkthrough

The pull request introduces multiple changes across several files, primarily focusing on enhancing the handling of database connections and serialization processes. Key modifications include the refactoring of the APIKeys class to improve instance method calls, the introduction of a DBManager class for centralized database management, and updates to caching mechanisms. Additionally, new serialization functions have been added, and existing tests have been updated to reflect these changes, ensuring consistent usage of parameters across the codebase.

Changes

File Change Summary
prediction_market_agent_tooling/config.py Modified APIKeys class methods to use self.model_fields for instance context; updated check_if_is_safe_owner method to validate SAFE_ADDRESS.
prediction_market_agent_tooling/markets/omen/data_models.py Updated from_tuple method signature to accept a variable-length tuple; simplified instance creation logic.
prediction_market_agent_tooling/tools/caches/db_cache.py Refactored db_cache decorator to utilize DBManager for session management and table creation; removed custom JSON serialization logic.
prediction_market_agent_tooling/tools/caches/serializers.py Introduced new file with functions for JSON serialization and deserialization of custom types.
prediction_market_agent_tooling/tools/db/db_manager.py Added DBManager class for managing database connections and sessions, including methods for session handling and table creation.
prediction_market_agent_tooling/tools/relevant_news_analysis/relevant_news_cache.py Updated RelevantNewsResponseCache to use DBManager for database interactions; changed constructor to accept APIKeys.
tests/conftest.py Replaced database connection fixture with one using PostgreSQLExecutor and improved lifecycle management.
tests/tools/db/test_db_manager.py Added test to verify singleton behavior of DBManager.
tests/tools/test_db_cache.py Updated test functions to use new database connection fixture.
tests_integration/tools/test_relevant_news_analysis.py Modified test to utilize APIKeys for database URL management, enhancing security handling.
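A hedged usage sketch of how the pieces summarized above fit together; the module paths come from the table, everything else is illustrative:

from prediction_market_agent_tooling.config import APIKeys
from prediction_market_agent_tooling.tools.db.db_manager import DBManager

keys = APIKeys()                  # presumably populated from the environment / .env
manager = DBManager(keys)         # per-URL singleton: the same URL yields the same instance
assert manager is DBManager(keys)

with manager.get_session() as session:
    ...  # queries run through the single shared engine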

Possibly related PRs

Suggested reviewers

  • evangriffiths

📜 Recent review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between e18d477 and fbbd617.

📒 Files selected for processing (2)
  • prediction_market_agent_tooling/tools/db/db_manager.py (1 hunks)
  • tests/tools/db/test_db_manager.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • prediction_market_agent_tooling/tools/db/db_manager.py
  • tests/tools/db/test_db_manager.py


prediction_market_agent_tooling/tools/caches/db_cache.py: two outdated review comments (resolved)

@contextmanager
def get_session(self) -> Session:
    session = Session(self._engine)
Contributor


Wouldn't

@contextmanager
def get_session(self) -> Session:
    with Session(self._engine) as session:
        yield session

have the same result, without doing all the stuff around it manually?
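
For reference, SQLAlchemy's Session (which SQLModel's Session subclasses) already implements the context-manager protocol and closes itself on exit, so the shorter form above gives the same cleanup without the manual bookkeeping (it still does not commit automatically). A minimal illustration:

from sqlmodel import Session, create_engine

engine = create_engine("sqlite://")  # illustrative in-memory engine
with Session(engine) as session:
    ...  # session.close() runs automatically when the block exits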

@gabrielfior
Contributor Author

Closing this since it has lower priority than the upcoming general agent milestone; also, this PR apparently already solved the most pressing issues.

@kongzii kongzii reopened this Nov 21, 2024
@kongzii kongzii marked this pull request as ready for review November 21, 2024 10:49

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 5

🧹 Outside diff range and nitpick comments (16)
tests/conftest.py (1)

22-23: Consider making cache enablement configurable

The cache is currently hardcoded to be enabled. Consider making this configurable through a fixture parameter to allow testing both cached and non-cached scenarios.

 @pytest.fixture(scope="session")
 def session_keys_with_postgresql_proc_and_enabled_cache(
     postgresql_proc: PostgreSQLExecutor,
+    enable_cache: bool = True,
 ) -> Generator[APIKeys, None, None]:
     with DatabaseJanitor(
         user=postgresql_proc.user,
         host=postgresql_proc.host,
         port=postgresql_proc.port,
         dbname=postgresql_proc.dbname,
         version=postgresql_proc.version,
     ):
         sqlalchemy_db_url = f"postgresql+psycopg2://{postgresql_proc.user}:@{postgresql_proc.host}:{postgresql_proc.port}/{postgresql_proc.dbname}"
-        yield APIKeys(SQLALCHEMY_DB_URL=SecretStr(sqlalchemy_db_url), ENABLE_CACHE=True)
+        yield APIKeys(SQLALCHEMY_DB_URL=SecretStr(sqlalchemy_db_url), ENABLE_CACHE=enable_cache)
tests/tools/db/test_db_manager.py (3)

27-27: Fix typo in assertion message

There's a typo in the assertion message: "isntance" should be "instance".

-    ), "DBManager returned isntance with a different SQLALCHEMY_DB_URL!"
+    ), "DBManager returned instance with a different SQLALCHEMY_DB_URL!"

8-28: Consider enhancing singleton pattern test coverage

While the test correctly verifies basic singleton behavior, consider adding these scenarios:

  1. Test thread safety of the singleton pattern
  2. Verify singleton state persistence
  3. Test edge cases with invalid DB URLs

Here's a suggested enhancement:

def test_DBManager_creates_only_one_instance() -> None:
    with tempfile.NamedTemporaryFile(suffix=".db") as temp_db1, \
         tempfile.NamedTemporaryFile(suffix=".db") as temp_db2:
        
        # Test basic singleton behavior
        db1 = DBManager(APIKeys(SQLALCHEMY_DB_URL=SecretStr(f"sqlite:///{temp_db1.name}")))
        db2 = DBManager(APIKeys(SQLALCHEMY_DB_URL=SecretStr(f"sqlite:///{temp_db1.name}")))
        db3 = DBManager(APIKeys(SQLALCHEMY_DB_URL=SecretStr(f"sqlite:///{temp_db2.name}")))
        
        # Verify instance equality
        assert db1 is db2, "DBManager created more than one instance!"
        assert db1 is not db3, "DBManager returned instance with a different SQLALCHEMY_DB_URL!"
        
        # Verify state persistence
        db1.some_state = "test"
        assert hasattr(db2, "some_state"), "Singleton state not preserved!"
        assert db2.some_state == "test", "Singleton state not shared!"
        
        # Test invalid URL handling
        with pytest.raises(ValueError):
            DBManager(APIKeys(SQLALCHEMY_DB_URL=SecretStr("invalid://url")))

30-35: Consider removing or completing the commented test before closure

Since this PR is marked for closure, it would be better to either:

  1. Remove the commented-out test completely, or
  2. Complete the test implementation if it's testing critical functionality

The current TODO and incomplete state adds unnecessary noise to the codebase.

prediction_market_agent_tooling/tools/caches/serializers.py (2)

10-11: Add error handling for json.dumps.

While the function is well-structured, it should handle potential JSON serialization errors to provide better error messages and prevent crashes.

 def json_serializer(x: t.Any) -> str:
-    return json.dumps(x, default=json_serializer_default_fn)
+    try:
+        return json.dumps(x, default=json_serializer_default_fn)
+    except TypeError as e:
+        raise TypeError(f"Failed to serialize object: {e}") from e

14-32: Improve error message specificity.

The error message could be more helpful by including the actual type of the unsupported value.

     elif isinstance(y, BaseModel):
         return y.model_dump()
     raise TypeError(
-        f"Unsuported type for the default json serialize function, value is {y}."
+        f"Unsupported type '{type(y).__name__}' for JSON serialization. Value: {y}"
     )
tests_integration/tools/test_relevant_news_analysis.py (1)

67-69: Consider extracting the test database URL to a fixture.

While the current implementation using SQLite in-memory database is appropriate for testing, consider moving the database configuration to a pytest fixture for better reusability across test files.

Example implementation:

@pytest.fixture
def test_db_config():
    return APIKeys(SQLALCHEMY_DB_URL=SecretStr("sqlite:///:memory:"))

def test_get_certified_relevant_news_since_cached(test_db_config) -> None:
    cache = RelevantNewsResponseCache(test_db_config)
prediction_market_agent_tooling/config.py (1)

Line range hint 199-205: Consider caching the ownership check result

While the validation and ownership check are correct, consider these improvements:

  1. Cache the ownership check result as it's unlikely to change frequently and the SafeV141 instantiation could be expensive.
  2. Add logging for security audit purposes when ownership checks fail.

Example implementation:

from functools import lru_cache
import logging

logger = logging.getLogger(__name__)

def check_if_is_safe_owner(self, ethereum_client: EthereumClient) -> bool:
    if not self.SAFE_ADDRESS:
        raise ValueError("Cannot check ownership if safe_address is not defined.")

    @lru_cache(maxsize=1)
    def _check_ownership(safe_address: str, public_key: str) -> bool:
        s = SafeV141(safe_address, ethereum_client)
        is_owner = s.retrieve_is_owner(public_key)
        if not is_owner:
            logger.warning(f"Ownership check failed for address {public_key}")
        return is_owner

    public_key_from_signer = private_key_to_public_key(self.bet_from_private_key)
    return _check_ownership(self.SAFE_ADDRESS, public_key_from_signer)
tests/tools/test_db_cache.py (2)

Line range hint 258-304: Refactor cache invalidation test to improve clarity

The current implementation uses class redefinition and multiple type ignore comments, which makes the test harder to maintain and understand.

Consider refactoring to use two separate, clearly named classes:

-    # Initial output model
-    class FirstOutputModel(TestOutputModel):
-        pass
+    class OriginalOutputModel(TestOutputModel):
+        pass
+
+    class ModifiedOutputModel(TestOutputModel):
+        new_field: str

-    @db_cache(api_keys=session_keys_with_postgresql_proc_and_enabled_cache)
-    def multiply_models(a: TestInputModel, b: TestInputModel) -> FirstOutputModel:
+    @db_cache(api_keys=session_keys_with_postgresql_proc_and_enabled_cache)
+    def multiply_models(a: TestInputModel, b: TestInputModel) -> OriginalOutputModel | ModifiedOutputModel:

This approach:

  • Eliminates the need for type ignore comments
  • Makes the test's intention clearer
  • Better represents real-world model evolution scenarios

Line range hint 1-324: Consider adding tests for error handling and edge cases

While the test suite comprehensively covers various data types and caching scenarios, consider adding tests for:

  1. Error handling:
    • Database connection failures
    • Serialization errors for complex objects
  2. Edge cases:
    • Very large objects that might hit size limits
    • Special characters in strings that might affect serialization

Would you like me to help generate these additional test cases?
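
A minimal sketch of one such edge-case test, assuming the db_cache decorator and the session fixture already used in this suite (the function and variable names are illustrative):

from prediction_market_agent_tooling.config import APIKeys
from prediction_market_agent_tooling.tools.caches.db_cache import db_cache


def test_cache_handles_special_characters(
    session_keys_with_postgresql_proc_and_enabled_cache: APIKeys,
) -> None:
    calls: list[str] = []

    @db_cache(api_keys=session_keys_with_postgresql_proc_and_enabled_cache)
    def echo(value: str) -> str:
        calls.append(value)
        return value

    tricky = 'quote " backslash \\ newline \n unicode ✓'
    assert echo(tricky) == tricky  # first call computes and stores the result
    assert echo(tricky) == tricky  # second call should be served from the cache
    assert len(calls) == 1, "Function body should have run only once."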

prediction_market_agent_tooling/tools/db/db_manager.py (3)

31-38: Configure echo Parameter Appropriately in create_engine

The echo=True parameter in create_engine enables SQLAlchemy to log all SQL statements, which may not be suitable for a production environment due to potential performance impacts and security concerns. It's advisable to set echo to False or make it configurable.

You can modify the code to make echo configurable:

 self._engine = create_engine(
     sqlalchemy_db_url.get_secret_value(),
     json_serializer=json_serializer,
     json_deserializer=json_deserializer,
     pool_size=20,
     pool_recycle=3600,
-    echo=True,
+    echo=False,
 )

Alternatively, you can use an environment variable or configuration setting to control this parameter.


54-69: Simplify create_tables Method for Improved Readability

The current implementation of the create_tables method is complex due to nested list comprehensions and the use of the walrus operator. This can make the code difficult to read and maintain.

Refactor the method to enhance clarity:

 def create_tables(
     self, sqlmodel_tables: Sequence[type[SQLModel]] | None = None
 ) -> None:
-    tables_to_create = (
-        [
-            table
-            for sqlmodel_table in sqlmodel_tables
-            if not self.cache_table_initialized.get(
-                (
-                    table := SQLModel.metadata.tables[
-                        cast(str, sqlmodel_table.__tablename__)
-                    ]
-                ).name
-            )
-        ]
-        if sqlmodel_tables is not None
-        else None
-    )
+    if sqlmodel_tables is not None:
+        tables_to_create = []
+        for sqlmodel_table in sqlmodel_tables:
+            table_name = cast(str, sqlmodel_table.__tablename__)
+            table = SQLModel.metadata.tables[table_name]
+            if not self.cache_table_initialized.get(table.name):
+                tables_to_create.append(table)
+    else:
+        tables_to_create = None

     SQLModel.metadata.create_all(self._engine, tables=tables_to_create)

     for table in tables_to_create or []:
         self.cache_table_initialized[table.name] = True

This refactored version reduces complexity and enhances readability by using straightforward control structures.


15-72: Add Docstrings to Class and Methods for Better Documentation

The DBManager class and its methods lack docstrings. Adding docstrings would improve code readability and help other developers understand the purpose and usage of each component.

Consider adding docstrings like the following:

class DBManager:
    """Manages database connections and sessions using SQLAlchemy and SQLModel."""

    def __new__(cls, api_keys: APIKeys | None = None) -> "DBManager":
        """Implements a singleton pattern based on the database URL."""

    def __init__(self, api_keys: APIKeys | None = None) -> None:
        """Initializes the database engine with custom serializers."""

    @contextmanager
    def get_session(self) -> Generator[Session, None, None]:
        """Provides a session context manager."""

    @contextmanager
    def get_connection(self) -> Generator[Connection, None, None]:
        """Provides a connection context manager."""

    def create_tables(
        self, sqlmodel_tables: Sequence[type[SQLModel]] | None = None
    ) -> None:
        """Creates tables if they have not been initialized."""
prediction_market_agent_tooling/tools/caches/db_cache.py (3)

101-101: Refactor to reuse DBManager instance within the decorator

Currently, DBManager(api_keys) is instantiated multiple times within the db_cache decorator, specifically at lines 101, 146, and 199. Creating a single instance of DBManager and reusing it throughout the decorator can improve efficiency by reducing redundant object creation and clarifying the code structure.

Apply this diff to refactor the code:

 def db_cache(
     func: FunctionT | None = None,
     *,
     max_age: timedelta | None = None,
     cache_none: bool = True,
     api_keys: APIKeys | None = None,
     ignore_args: Sequence[str] | None = None,
     ignore_arg_types: Sequence[type] | None = None,
 ) -> FunctionT | Callable[[FunctionT], FunctionT]:
     if func is None:
         # Ugly Pythonic way to support this decorator as `@postgres_cache` but also `@postgres_cache(max_age=timedelta(days=3))`
         def decorator(func: FunctionT) -> FunctionT:
             return db_cache(
                 func,
                 max_age=max_age,
                 cache_none=cache_none,
                 api_keys=api_keys,
                 ignore_args=ignore_args,
                 ignore_arg_types=ignore_arg_types,
             )

         return decorator

     api_keys = api_keys if api_keys is not None else APIKeys()
+    db_manager = DBManager(api_keys)

     @wraps(func)
     def wrapper(*args: Any, **kwargs: Any) -> Any:
         # If caching is disabled, just call the function and return it
         if not api_keys.ENABLE_CACHE:
             return func(*args, **kwargs)

-        DBManager(api_keys).create_tables([FunctionCache])
+        db_manager.create_tables([FunctionCache])

         # Convert *args and **kwargs to a single dictionary...
         # [Existing code continues]

         # Database session for retrieving cached result
-        with DBManager(api_keys).get_session() as session:
+        with db_manager.get_session() as session:
             # Try to get cached result
             # [Existing code continues]

         # On cache miss, compute the result
         computed_result = func(*args, **kwargs)
         # [Logging and other logic]

         # Save the computed result to cache if applicable
         if cache_none or computed_result is not None:
             cache_entry = FunctionCache(
                 function_name=function_name,
                 full_function_name=full_function_name,
                 args_hash=args_hash,
                 args=args_dict,
                 result=computed_result,
                 created_at=utcnow(),
             )
-            with DBManager(api_keys).get_session() as session:
+            with db_manager.get_session() as session:
                 logger.info(f"Saving {cache_entry} into database.")
                 session.add(cache_entry)
                 session.commit()

         return computed_result

     return cast(FunctionT, wrapper)

Also applies to: 146-146, 199-199


142-144: Simplify the check for Pydantic models in the return type

The condition to check if the return type contains a Pydantic model can be simplified for readability. Since contains_pydantic_model already checks for None, the return_type is not None check is redundant.

Apply this diff to simplify the condition:

-    is_pydantic_model = return_type is not None and contains_pydantic_model(
-        return_type
-    )
+    is_pydantic_model = contains_pydantic_model(return_type)

146-146: Add exception handling for database session operations

While using the with statement ensures that sessions are properly closed, it's prudent to add exception handling around database operations to catch and handle potential exceptions such as connection errors or data integrity issues. This enhances the robustness of the application.

Consider wrapping database interactions in try-except blocks:

 with db_manager.get_session() as session:
+    try:
         # Database operations
+    except Exception as e:
+        logger.error(f"Database operation failed: {e}")
+        raise

Also applies to: 199-199

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 418e1dc and e18d477.

⛔ Files ignored due to path filters (2)
  • poetry.lock is excluded by !**/*.lock
  • pyproject.toml is excluded by !**/*.toml
📒 Files selected for processing (10)
  • prediction_market_agent_tooling/config.py (1 hunks)
  • prediction_market_agent_tooling/markets/omen/data_models.py (1 hunks)
  • prediction_market_agent_tooling/tools/caches/db_cache.py (5 hunks)
  • prediction_market_agent_tooling/tools/caches/serializers.py (1 hunks)
  • prediction_market_agent_tooling/tools/db/db_manager.py (1 hunks)
  • prediction_market_agent_tooling/tools/relevant_news_analysis/relevant_news_cache.py (3 hunks)
  • tests/conftest.py (1 hunks)
  • tests/tools/db/test_db_manager.py (1 hunks)
  • tests/tools/test_db_cache.py (15 hunks)
  • tests_integration/tools/test_relevant_news_analysis.py (2 hunks)
🔇 Additional comments (18)
tests/conftest.py (3)

5-6: LGTM! Well-organized imports

The new imports for PostgreSQL testing utilities follow proper import ordering conventions and are appropriate for the implemented functionality.


15-21: LGTM! Proper database lifecycle management

The use of DatabaseJanitor ensures proper cleanup of database resources during testing. The parameters are correctly derived from postgresql_proc.


11-13: Verify fixture updates across test files

This is a breaking change as the fixture name has changed from keys_with_sqlalchemy_db_url to session_keys_with_postgresql_proc_and_enabled_cache. Let's verify that all test files have been updated accordingly.

✅ Verification successful

Fixture update verification completed successfully

The search results show that:

  1. There are no remaining references to the old fixture name keys_with_sqlalchemy_db_url in the tests directory
  2. The new fixture name session_keys_with_postgresql_proc_and_enabled_cache is properly used across test files, particularly in tests/tools/test_db_cache.py
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for any remaining usage of the old fixture name
echo "Checking for old fixture usage..."
rg "keys_with_sqlalchemy_db_url" tests/

# Search for usage of the new fixture name to ensure it's being used
echo "Checking new fixture usage..."
rg "session_keys_with_postgresql_proc_and_enabled_cache" tests/

Length of output: 3539

tests/tools/db/test_db_manager.py (1)

1-6: LGTM! Well-structured imports.

The imports are appropriate and follow security best practices by using SecretStr for database URLs.

prediction_market_agent_tooling/tools/caches/serializers.py (2)

1-8: LGTM! Well-organized imports.

The imports follow good practices with standard library imports first, followed by third-party and local imports. All imports are utilized in the code.


1-61: Verify bidirectional serialization for all supported types.

Let's ensure that all supported types can be correctly serialized and deserialized back to their original form.

✅ Verification successful

Serialization system is correctly implemented and used

The verification shows that the serialization system properly handles all supported types:

  • DatetimeUTC: Correctly serialized with ISO format and prefix
  • timedelta: Properly handled using total seconds
  • date: Serialized using ISO format
  • BaseModel: Serialized using model_dump()

The implementation is used consistently throughout the codebase, primarily in:

  • Database management (db_manager.py)
  • Cache implementations
  • Market data models
  • API responses
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Find all usages of the serialization functions to verify type coverage

# Look for direct usage of serialization functions
echo "Direct usage of serializers:"
rg -l "json_serializer|json_deserializer" --type python

# Look for the custom types being serialized
echo -e "\nCustom type usage:"
rg -l "DatetimeUTC|timedelta|BaseModel" --type python

# Look for potential new types that might need serialization
echo -e "\nPotential new types to consider:"
ast-grep --pattern 'class $name(BaseModel)' --type python

Length of output: 616


Script:

#!/bin/bash
# Description: Find all usages of the serialization functions to verify type coverage

# Look for direct usage of serialization functions
echo "Direct usage of serializers:"
rg "json_serializer|json_deserializer"

# Look for the custom types being serialized
echo -e "\nCustom type usage:"
rg "DatetimeUTC|timedelta|BaseModel"

# Look for potential new types that might need serialization
echo -e "\nPotential new types to consider:"
ast-grep 'class $name(BaseModel)'

Length of output: 50975
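
A hedged round-trip sketch of the prefix scheme verified above, using json_serializer and json_deserializer from prediction_market_agent_tooling/tools/caches/serializers.py:

from datetime import date, timedelta

from prediction_market_agent_tooling.tools.caches.serializers import (
    json_deserializer,
    json_serializer,
)

payload = {"day": date(2024, 11, 21), "ttl": timedelta(hours=1)}
restored = json_deserializer(json_serializer(payload))
assert restored == payload  # date and timedelta survive the string round trip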

tests_integration/tools/test_relevant_news_analysis.py (2)

5-5: LGTM! Good security practice with SecretStr.

The addition of SecretStr from pydantic and APIKeys improves the handling of sensitive configuration data.

Also applies to: 7-7


67-69: Verify dependencies before closing the PR.

Since this PR is marked for closure, we should verify that PR #547 fully addresses the database management needs and that there are no dependent changes waiting on this implementation.

prediction_market_agent_tooling/config.py (2)

189-189: LGTM: Good practice using instance attribute

The change to use self.model_fields instead of the class attribute is a better practice as it makes the code more maintainable and allows for proper inheritance.


196-196: LGTM: Consistent with previous change

The change maintains consistency with model_dump_public and properly handles secret field access.

tests/tools/test_db_cache.py (2)

Line range hint 85-104: LGTM! Good use of immutable configuration update

The test correctly uses model_copy to create a new instance with disabled cache, maintaining immutability. The assertion message clearly explains the test's purpose.


11-11: Verify the new test fixture implementation

The consistent change from keys_with_sqlalchemy_db_url to session_keys_with_postgresql_proc_and_enabled_cache across all test functions suggests a significant change in the database handling approach. This change appears to move from SQLAlchemy to PostgreSQL with process-based management.

Let's verify the fixture implementation:

Also applies to: 30-30, 47-47, 64-64, 85-85, 109-109, 126-126, 144-144, 162-162, 187-187, 210-210, 233-233, 258-258, 305-305

✅ Verification successful

Fixture implementation is properly managed and test isolation is maintained

The fixture session_keys_with_postgresql_proc_and_enabled_cache is well-implemented in conftest.py using pytest-postgresql. It:

  • Uses PostgreSQLExecutor for process management
  • Employs DatabaseJanitor for clean setup/teardown
  • Properly configures PostgreSQL connection with user, host, port, and database name
  • Ensures test isolation through process-based PostgreSQL instances

The change from SQLAlchemy URL to PostgreSQL process is a valid improvement that maintains test isolation and provides better control over the database lifecycle.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check the fixture implementation and its usage
# Look for the fixture definition and any related configuration

# Search for fixture definition
echo "Searching for fixture definition..."
rg -A 10 "session_keys_with_postgresql_proc_and_enabled_cache"

# Search for PostgreSQL process handling
echo "Searching for PostgreSQL process handling..."
rg "PostgreSQLExecutor|postgresql_proc"

Length of output: 19913

prediction_market_agent_tooling/tools/relevant_news_analysis/relevant_news_cache.py (6)

4-4: Imports are appropriate and utilized correctly.

The import statement includes all necessary modules and they are used in the code.


8-8: Proper import of DBManager for centralized database management.

Importing DBManager enhances modularity by centralizing database operations.


27-28: Constructor updated to use APIKeys and initialize DBManager.

The __init__ method now accepts an APIKeys instance and initializes DBManager, improving configuration management.


35-35: Delegating table creation to DBManager enhances maintainability.

Using self.db_manager.create_tables([RelevantNewsCacheModel]) centralizes table management.


42-42: Consistent session management using DBManager.

Replacing direct session creation with self.db_manager.get_session() ensures consistent and reliable session handling.


75-75: Aligned session usage in save method with updated pattern.

The save method now utilizes self.db_manager.get_session(), maintaining consistency across database operations.
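
A hedged sketch of the access pattern described in these comments; the concrete queries are omitted and only names mentioned in the review are assumed to exist:

from prediction_market_agent_tooling.config import APIKeys
from prediction_market_agent_tooling.tools.db.db_manager import DBManager
from prediction_market_agent_tooling.tools.relevant_news_analysis.relevant_news_cache import (
    RelevantNewsCacheModel,
)

db_manager = DBManager(APIKeys())
db_manager.create_tables([RelevantNewsCacheModel])

with db_manager.get_session() as session:
    ...  # query or insert RelevantNewsCacheModel rows here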

Comment on lines +35 to +37
def json_deserializer(s: str) -> t.Any:
    data = json.loads(s)
    return replace_custom_stringified_objects(data)

🛠️ Refactor suggestion

Add error handling and documentation.

The function should handle JSON decode errors and include documentation explaining its purpose and expected input/output.

 def json_deserializer(s: str) -> t.Any:
+    """
+    Deserialize a JSON string into Python objects, handling custom type deserialization.
+    
+    Args:
+        s: JSON string to deserialize
+    
+    Returns:
+        Deserialized Python object
+        
+    Raises:
+        json.JSONDecodeError: If the input string is not valid JSON
+    """
-    data = json.loads(s)
-    return replace_custom_stringified_objects(data)
+    try:
+        data = json.loads(s)
+        return replace_custom_stringified_objects(data)
+    except json.JSONDecodeError as e:
+        raise json.JSONDecodeError(
+            f"Failed to decode JSON string: {e.msg}",
+            e.doc,
+            e.pos,
+        ) from e

Comment on lines +40 to +61
def replace_custom_stringified_objects(obj: t.Any) -> t.Any:
    """
    Used to deserialize objects from `json_serializer_default_fn` into their proper form.
    """
    if isinstance(obj, str):
        if obj.startswith("DatetimeUTC::"):
            iso_str = obj[len("DatetimeUTC::") :]
            return DatetimeUTC.to_datetime_utc(iso_str)
        elif obj.startswith("timedelta::"):
            total_seconds_str = obj[len("timedelta::") :]
            return timedelta(seconds=float(total_seconds_str))
        elif obj.startswith("date::"):
            iso_str = obj[len("date::") :]
            return date.fromisoformat(iso_str)
        else:
            return obj
    elif isinstance(obj, dict):
        return {k: replace_custom_stringified_objects(v) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [replace_custom_stringified_objects(item) for item in obj]
    else:
        return obj

🛠️ Refactor suggestion

Enhance robustness and maintainability of custom type handling.

The current implementation using string prefixes could be improved by:

  1. Using constants for type prefixes
  2. Adding format validation
  3. Making the type handling more maintainable
+# At the top of the file
+class SerializerPrefix:
+    DATETIME_UTC = "DatetimeUTC::"
+    TIMEDELTA = "timedelta::"
+    DATE = "date::"
+
+def _parse_prefixed_value(value: str, prefix: str, parser: t.Callable[[str], t.Any]) -> tuple[bool, t.Any]:
+    """Helper to safely parse prefixed values"""
+    if value.startswith(prefix):
+        try:
+            return True, parser(value[len(prefix):])
+        except ValueError as e:
+            raise ValueError(f"Invalid format for {prefix.rstrip(':')}: {e}") from e
+    return False, value
+
 def replace_custom_stringified_objects(obj: t.Any) -> t.Any:
     """
     Used to deserialize objects from `json_serializer_default_fn` into their proper form.
     """
     if isinstance(obj, str):
-        if obj.startswith("DatetimeUTC::"):
-            iso_str = obj[len("DatetimeUTC::") :]
-            return DatetimeUTC.to_datetime_utc(iso_str)
-        elif obj.startswith("timedelta::"):
-            total_seconds_str = obj[len("timedelta::") :]
-            return timedelta(seconds=float(total_seconds_str))
-        elif obj.startswith("date::"):
-            iso_str = obj[len("date::") :]
-            return date.fromisoformat(iso_str)
-        else:
-            return obj
+        # Try each type parser in sequence
+        matched, result = _parse_prefixed_value(
+            obj, SerializerPrefix.DATETIME_UTC,
+            DatetimeUTC.to_datetime_utc
+        )
+        if matched:
+            return result
+            
+        matched, result = _parse_prefixed_value(
+            obj, SerializerPrefix.TIMEDELTA,
+            lambda x: timedelta(seconds=float(x))
+        )
+        if matched:
+            return result
+            
+        matched, result = _parse_prefixed_value(
+            obj, SerializerPrefix.DATE,
+            date.fromisoformat
+        )
+        if matched:
+            return result
+            
+        return obj

This refactoring:

  • Centralizes prefix definitions
  • Adds proper error handling for malformed values
  • Makes adding new types easier
  • Improves error messages

Consider using a registry pattern if more custom types need to be added in the future.
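
A hedged sketch of that registry pattern: a prefix-to-parser mapping the deserializer can iterate over, so supporting a new type becomes a one-line addition (names here are illustrative, not part of the PR):

import typing as t
from datetime import date, timedelta

PREFIX_PARSERS: dict[str, t.Callable[[str], t.Any]] = {
    "timedelta::": lambda s: timedelta(seconds=float(s)),
    "date::": date.fromisoformat,
    # "DatetimeUTC::": DatetimeUTC.to_datetime_utc,  # project-specific type, registered the same way
}


def parse_prefixed(value: str) -> t.Any:
    for prefix, parser in PREFIX_PARSERS.items():
        if value.startswith(prefix):
            return parser(value[len(prefix) :])
    return value


assert parse_prefixed("timedelta::90.0") == timedelta(seconds=90)
assert parse_prefixed("plain string") == "plain string"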


Comment on lines +803 to +809
def from_tuple(values: tuple[t.Any, ...]) -> "ContractPrediction":
    return ContractPrediction(
        publisher=values[0],
        ipfs_hash=values[1],
        tx_hashes=values[2],
        estimated_probability_bps=values[3],
    )

🛠️ Refactor suggestion

Add validation and documentation for the tuple structure.

The from_tuple method could benefit from additional validation and documentation:

    @staticmethod
-    def from_tuple(values: tuple[t.Any, ...]) -> "ContractPrediction":
-        return ContractPrediction(
-            publisher=values[0],
-            ipfs_hash=values[1],
-            tx_hashes=values[2],
-            estimated_probability_bps=values[3],
-        )
+    def from_tuple(values: tuple[t.Any, ...]) -> "ContractPrediction":
+        """Create a ContractPrediction instance from a tuple.
+        
+        Args:
+            values: A tuple containing (publisher: str, ipfs_hash: HexBytes,
+                    tx_hashes: list[HexBytes], estimated_probability_bps: int)
+        
+        Raises:
+            ValueError: If the tuple doesn't have exactly 4 elements
+            TypeError: If the elements don't match the expected types
+        """
+        if len(values) != 4:
+            raise ValueError(f"Expected 4 values, got {len(values)}")
+        
+        publisher, ipfs_hash, tx_hashes, probability_bps = values
+        
+        if not isinstance(probability_bps, int):
+            raise TypeError(f"estimated_probability_bps must be an integer, got {type(probability_bps)}")
+        
+        return ContractPrediction(
+            publisher=publisher,
+            ipfs_hash=ipfs_hash,
+            tx_hashes=tx_hashes,
+            estimated_probability_bps=probability_bps,
+        )

The changes:

  1. Add docstring explaining the expected tuple structure
  2. Add tuple length validation
  3. Add type validation for critical fields
  4. Use tuple unpacking for better readability

Comment on lines 18 to 25
def __new__(cls, api_keys: APIKeys | None = None) -> "DBManager":
    sqlalchemy_db_url = (api_keys or APIKeys()).sqlalchemy_db_url
    secret_value = sqlalchemy_db_url.get_secret_value()
    url_hash = hashlib.md5(secret_value.encode()).hexdigest()
    if url_hash not in cls._instances:
        instance = super(DBManager, cls).__new__(cls)
        cls._instances[url_hash] = instance
    return cls._instances[url_hash]

⚠️ Potential issue

Potential Thread-Safety Issue in Singleton Implementation

The _instances dictionary used for implementing the singleton pattern is not thread-safe. In a multi-threaded environment, simultaneous access to _instances without synchronization could lead to race conditions, resulting in multiple instances of DBManager for the same database URL.

Consider adding a threading lock to synchronize access:

+import threading

 class DBManager:
+    _lock = threading.Lock()
     _instances: dict[str, "DBManager"] = {}

     def __new__(cls, api_keys: APIKeys | None = None) -> "DBManager":
         sqlalchemy_db_url = (api_keys or APIKeys()).sqlalchemy_db_url
         secret_value = sqlalchemy_db_url.get_secret_value()
         url_hash = hashlib.md5(secret_value.encode()).hexdigest()
+        with cls._lock:
             if url_hash not in cls._instances:
                 instance = super(DBManager, cls).__new__(cls)
                 cls._instances[url_hash] = instance
         return cls._instances[url_hash]


# Create table if it doesn't exist
SQLModel.metadata.create_all(engine)
DBManager(api_keys).create_tables([FunctionCache])

🛠️ Refactor suggestion

Ensure table creation occurs only once during application startup

Calling create_tables([FunctionCache]) inside the wrapper function results in attempting to create the table every time a cached function is called, which can be inefficient. It's better to initialize the necessary tables once during the application startup or module import.

Consider moving the table creation logic outside the wrapper function. For example:

 def db_cache(
     func: FunctionT | None = None,
     *,
     max_age: timedelta | None = None,
     cache_none: bool = True,
     api_keys: APIKeys | None = None,
     ignore_args: Sequence[str] | None = None,
     ignore_arg_types: Sequence[type] | None = None,
 ) -> FunctionT | Callable[[FunctionT], FunctionT]:
     if func is None:
         # [Existing code continues]

     api_keys = api_keys if api_keys is not None else APIKeys()
     db_manager = DBManager(api_keys)
+    # Ensure tables are created once
+    db_manager.create_tables([FunctionCache])

     @wraps(func)
     def wrapper(*args: Any, **kwargs: Any) -> Any:
         # [Existing code continues]

-        db_manager.create_tables([FunctionCache])
         # [Existing code continues]

If moving the table creation causes issues with tests or application startup, consider initializing the tables in a centralized location where the application context is first established.

Committable suggestion skipped: line range outside the PR's diff.

Development

Successfully merging this pull request may close these issues.

Have only one instance of engine connection to Postgres
2 participants