First attempt at DB_Manager #555
base: main
Conversation
Walkthrough

The pull request introduces multiple changes across several files, primarily focusing on enhancing the handling of database connections and serialization processes. Key modifications include refactoring database session handling into a new DBManager class and introducing custom JSON serializers for the cache layer.
```python
    @contextmanager
    def get_session(self) -> Session:
        session = Session(self._engine)
```
Wouldn't

```python
@contextmanager
def get_session(self) -> Session:
    with Session(self._engine) as session:
        yield session
```

have the same result, without doing all the stuff around it manually?
Closing this since it has lower priority than the upcoming general agent milestone; also, this PR apparently already solved the most pressing issues.
Actionable comments posted: 5
🧹 Outside diff range and nitpick comments (16)
tests/conftest.py (1)
22-23: Consider making cache enablement configurable

The cache is currently hardcoded to be enabled. Consider making this configurable through a fixture parameter to allow testing both cached and non-cached scenarios.
```diff
 @pytest.fixture(scope="session")
 def session_keys_with_postgresql_proc_and_enabled_cache(
     postgresql_proc: PostgreSQLExecutor,
+    enable_cache: bool = True,
 ) -> Generator[APIKeys, None, None]:
     with DatabaseJanitor(
         user=postgresql_proc.user,
         host=postgresql_proc.host,
         port=postgresql_proc.port,
         dbname=postgresql_proc.dbname,
         version=postgresql_proc.version,
     ):
         sqlalchemy_db_url = f"postgresql+psycopg2://{postgresql_proc.user}:@{postgresql_proc.host}:{postgresql_proc.port}/{postgresql_proc.dbname}"
-        yield APIKeys(SQLALCHEMY_DB_URL=SecretStr(sqlalchemy_db_url), ENABLE_CACHE=True)
+        yield APIKeys(SQLALCHEMY_DB_URL=SecretStr(sqlalchemy_db_url), ENABLE_CACHE=enable_cache)
```

tests/tools/db/test_db_manager.py (3)
27-27: Fix typo in assertion message

There's a typo in the assertion message: "isntance" should be "instance".

```diff
-    ), "DBManager returned isntance with a different SQLALCHEMY_DB_URL!"
+    ), "DBManager returned instance with a different SQLALCHEMY_DB_URL!"
```
8-28: Consider enhancing singleton pattern test coverage

While the test correctly verifies basic singleton behavior, consider adding these scenarios:
- Test thread safety of the singleton pattern
- Verify singleton state persistence
- Test edge cases with invalid DB URLs
Here's a suggested enhancement:
```python
def test_DBManager_creates_only_one_instance() -> None:
    with tempfile.NamedTemporaryFile(suffix=".db") as temp_db1, \
         tempfile.NamedTemporaryFile(suffix=".db") as temp_db2:
        # Test basic singleton behavior
        db1 = DBManager(APIKeys(SQLALCHEMY_DB_URL=SecretStr(f"sqlite:///{temp_db1.name}")))
        db2 = DBManager(APIKeys(SQLALCHEMY_DB_URL=SecretStr(f"sqlite:///{temp_db1.name}")))
        db3 = DBManager(APIKeys(SQLALCHEMY_DB_URL=SecretStr(f"sqlite:///{temp_db2.name}")))

        # Verify instance equality
        assert db1 is db2, "DBManager created more than one instance!"
        assert db1 is not db3, "DBManager returned instance with a different SQLALCHEMY_DB_URL!"

        # Verify state persistence
        db1.some_state = "test"
        assert hasattr(db2, "some_state"), "Singleton state not preserved!"
        assert db2.some_state == "test", "Singleton state not shared!"

        # Test invalid URL handling
        with pytest.raises(ValueError):
            DBManager(APIKeys(SQLALCHEMY_DB_URL=SecretStr("invalid://url")))
```
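The thread-safety scenario from the first bullet above is not covered by that snippet. A minimal sketch of such a test, assuming DBManager and APIKeys behave as described in this PR (the test name and the use of pytest's tmp_path fixture are illustrative), could look like this:

```python
import threading

from pydantic import SecretStr

from prediction_market_agent_tooling.config import APIKeys
from prediction_market_agent_tooling.tools.db.db_manager import DBManager


def test_DBManager_singleton_under_concurrent_construction(tmp_path) -> None:
    url = SecretStr(f"sqlite:///{tmp_path / 'threads.db'}")
    instances: list[DBManager] = []

    def construct() -> None:
        # list.append is atomic under the GIL, so no extra locking is needed here.
        instances.append(DBManager(APIKeys(SQLALCHEMY_DB_URL=url)))

    threads = [threading.Thread(target=construct) for _ in range(10)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

    # With a properly synchronized __new__, every thread should see the same instance.
    assert all(instance is instances[0] for instance in instances)
```

Note that with the current unsynchronized implementation (see the thread-safety comment further below), this test may fail intermittently, which is exactly the race it is meant to expose.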
30-35: Consider removing or completing the commented test before closure

Since this PR is marked for closure, it would be better to either:
- Remove the commented-out test completely, or
- Complete the test implementation if it's testing critical functionality
The current TODO and incomplete state adds unnecessary noise to the codebase.
prediction_market_agent_tooling/tools/caches/serializers.py (2)
10-11: Add error handling for json.dumps.

While the function is well-structured, it should handle potential JSON serialization errors to provide better error messages and prevent crashes.
```diff
 def json_serializer(x: t.Any) -> str:
-    return json.dumps(x, default=json_serializer_default_fn)
+    try:
+        return json.dumps(x, default=json_serializer_default_fn)
+    except TypeError as e:
+        raise TypeError(f"Failed to serialize object: {e}") from e
```
14-32: Improve error message specificity.

The error message could be more helpful by including the actual type of the unsupported value.
```diff
     elif isinstance(y, BaseModel):
         return y.model_dump()
     raise TypeError(
-        f"Unsuported type for the default json serialize function, value is {y}."
+        f"Unsupported type '{type(y).__name__}' for JSON serialization. Value: {y}"
     )
```

tests_integration/tools/test_relevant_news_analysis.py (1)
67-69: Consider extracting the test database URL to a fixture.

While the current implementation using an SQLite in-memory database is appropriate for testing, consider moving the database configuration to a pytest fixture for better reusability across test files.

Example implementation:
```python
@pytest.fixture
def test_db_config():
    return APIKeys(SQLALCHEMY_DB_URL=SecretStr("sqlite:///:memory:"))


def test_get_certified_relevant_news_since_cached(test_db_config) -> None:
    cache = RelevantNewsResponseCache(test_db_config)
```

prediction_market_agent_tooling/config.py (1)
Line range hint 199-205: Consider caching the ownership check result

While the validation and ownership check are correct, consider these improvements:
- Cache the ownership check result as it's unlikely to change frequently and the SafeV141 instantiation could be expensive.
- Add logging for security audit purposes when ownership checks fail.
Example implementation:
```python
from functools import lru_cache
import logging

logger = logging.getLogger(__name__)


def check_if_is_safe_owner(self, ethereum_client: EthereumClient) -> bool:
    if not self.SAFE_ADDRESS:
        raise ValueError("Cannot check ownership if safe_address is not defined.")

    @lru_cache(maxsize=1)
    def _check_ownership(safe_address: str, public_key: str) -> bool:
        s = SafeV141(safe_address, ethereum_client)
        is_owner = s.retrieve_is_owner(public_key)
        if not is_owner:
            logger.warning(f"Ownership check failed for address {public_key}")
        return is_owner

    public_key_from_signer = private_key_to_public_key(self.bet_from_private_key)
    return _check_ownership(self.SAFE_ADDRESS, public_key_from_signer)
```

tests/tools/test_db_cache.py (2)
Line range hint 258-304: Refactor cache invalidation test to improve clarity

The current implementation uses class redefinition and multiple type ignore comments, which makes the test harder to maintain and understand.
Consider refactoring to use two separate, clearly named classes:
```diff
-    # Initial output model
-    class FirstOutputModel(TestOutputModel):
-        pass
+    class OriginalOutputModel(TestOutputModel):
+        pass
+
+    class ModifiedOutputModel(TestOutputModel):
+        new_field: str

-    @db_cache(api_keys=session_keys_with_postgresql_proc_and_enabled_cache)
-    def multiply_models(a: TestInputModel, b: TestInputModel) -> FirstOutputModel:
+    @db_cache(api_keys=session_keys_with_postgresql_proc_and_enabled_cache)
+    def multiply_models(a: TestInputModel, b: TestInputModel) -> OriginalOutputModel | ModifiedOutputModel:
```

This approach:
- Eliminates the need for type ignore comments
- Makes the test's intention clearer
- Better represents real-world model evolution scenarios
Line range hint 1-324: Consider adding tests for error handling and edge cases

While the test suite comprehensively covers various data types and caching scenarios, consider adding tests for:
- Error handling:
- Database connection failures
- Serialization errors for complex objects
- Edge cases:
- Very large objects that might hit size limits
- Special characters in strings that might affect serialization
Would you like me to help generate these additional test cases?
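If it helps, here is a rough sketch of what such tests could look like, assuming the db_cache decorator and the fixture used elsewhere in this test file; the exact exception raised on connection failure is an assumption and may instead be a specific SQLAlchemy error:

```python
import pytest
from pydantic import SecretStr

from prediction_market_agent_tooling.config import APIKeys
from prediction_market_agent_tooling.tools.caches.db_cache import db_cache


def test_db_cache_raises_on_unreachable_database() -> None:
    # Port 1 is assumed to be closed, so the first real connection attempt should fail.
    bad_keys = APIKeys(
        SQLALCHEMY_DB_URL=SecretStr("postgresql+psycopg2://user:pw@localhost:1/missing"),
        ENABLE_CACHE=True,
    )

    @db_cache(api_keys=bad_keys)
    def add(a: int, b: int) -> int:
        return a + b

    # Depending on the desired behaviour, this could instead fall back to an uncached call.
    with pytest.raises(Exception):
        add(1, 2)


def test_db_cache_round_trips_special_characters(
    session_keys_with_postgresql_proc_and_enabled_cache: APIKeys,
) -> None:
    @db_cache(api_keys=session_keys_with_postgresql_proc_and_enabled_cache)
    def echo(value: str) -> str:
        return value

    tricky = 'quotes "\' unicode ✓ and a newline\n'
    assert echo(tricky) == tricky
    assert echo(tricky) == tricky  # second call should be served from the cache
```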
prediction_market_agent_tooling/tools/db/db_manager.py (3)
31-38: Configure echo Parameter Appropriately in create_engine

The echo=True parameter in create_engine enables SQLAlchemy to log all SQL statements, which may not be suitable for a production environment due to potential performance impacts and security concerns. It's advisable to set echo to False or make it configurable.

You can modify the code to make echo configurable:

```diff
 self._engine = create_engine(
     sqlalchemy_db_url.get_secret_value(),
     json_serializer=json_serializer,
     json_deserializer=json_deserializer,
     pool_size=20,
     pool_recycle=3600,
-    echo=True,
+    echo=False,
 )
```

Alternatively, you can use an environment variable or configuration setting to control this parameter.
54-69: Simplify create_tables Method for Improved Readability

The current implementation of the create_tables method is complex due to nested list comprehensions and the use of the walrus operator. This can make the code difficult to read and maintain.

Refactor the method to enhance clarity:

```diff
 def create_tables(
     self, sqlmodel_tables: Sequence[type[SQLModel]] | None = None
 ) -> None:
-    tables_to_create = (
-        [
-            table
-            for sqlmodel_table in sqlmodel_tables
-            if not self.cache_table_initialized.get(
-                (
-                    table := SQLModel.metadata.tables[
-                        cast(str, sqlmodel_table.__tablename__)
-                    ]
-                ).name
-            )
-        ]
-        if sqlmodel_tables is not None
-        else None
-    )
+    if sqlmodel_tables is not None:
+        tables_to_create = []
+        for sqlmodel_table in sqlmodel_tables:
+            table_name = cast(str, sqlmodel_table.__tablename__)
+            table = SQLModel.metadata.tables[table_name]
+            if not self.cache_table_initialized.get(table.name):
+                tables_to_create.append(table)
+    else:
+        tables_to_create = None

     SQLModel.metadata.create_all(self._engine, tables=tables_to_create)

     for table in tables_to_create or []:
         self.cache_table_initialized[table.name] = True
```

This refactored version reduces complexity and enhances readability by using straightforward control structures.
15-72: Add Docstrings to Class and Methods for Better Documentation

The DBManager class and its methods lack docstrings. Adding docstrings would improve code readability and help other developers understand the purpose and usage of each component.

Consider adding docstrings like the following:

```python
class DBManager:
    """Manages database connections and sessions using SQLAlchemy and SQLModel."""

    def __new__(cls, api_keys: APIKeys | None = None) -> "DBManager":
        """Implements a singleton pattern based on the database URL."""

    def __init__(self, api_keys: APIKeys | None = None) -> None:
        """Initializes the database engine with custom serializers."""

    @contextmanager
    def get_session(self) -> Generator[Session, None, None]:
        """Provides a session context manager."""

    @contextmanager
    def get_connection(self) -> Generator[Connection, None, None]:
        """Provides a connection context manager."""

    def create_tables(
        self, sqlmodel_tables: Sequence[type[SQLModel]] | None = None
    ) -> None:
        """Creates tables if they have not been initialized."""
```

prediction_market_agent_tooling/tools/caches/db_cache.py (3)
101-101: Refactor to reuse DBManager instance within the decorator

Currently, DBManager(api_keys) is instantiated multiple times within the db_cache decorator, specifically at lines 101, 146, and 199. Creating a single instance of DBManager and reusing it throughout the decorator can improve efficiency by reducing redundant object creation and clarifying the code structure.

Apply this diff to refactor the code:

```diff
 def db_cache(
     func: FunctionT | None = None,
     *,
     max_age: timedelta | None = None,
     cache_none: bool = True,
     api_keys: APIKeys | None = None,
     ignore_args: Sequence[str] | None = None,
     ignore_arg_types: Sequence[type] | None = None,
 ) -> FunctionT | Callable[[FunctionT], FunctionT]:
     if func is None:
         # Ugly Pythonic way to support this decorator as `@postgres_cache` but also `@postgres_cache(max_age=timedelta(days=3))`
         def decorator(func: FunctionT) -> FunctionT:
             return db_cache(
                 func,
                 max_age=max_age,
                 cache_none=cache_none,
                 api_keys=api_keys,
                 ignore_args=ignore_args,
                 ignore_arg_types=ignore_arg_types,
             )

         return decorator

     api_keys = api_keys if api_keys is not None else APIKeys()
+    db_manager = DBManager(api_keys)

     @wraps(func)
     def wrapper(*args: Any, **kwargs: Any) -> Any:
         # If caching is disabled, just call the function and return it
         if not api_keys.ENABLE_CACHE:
             return func(*args, **kwargs)

-        DBManager(api_keys).create_tables([FunctionCache])
+        db_manager.create_tables([FunctionCache])

         # Convert *args and **kwargs to a single dictionary...
         # [Existing code continues]

         # Database session for retrieving cached result
-        with DBManager(api_keys).get_session() as session:
+        with db_manager.get_session() as session:
             # Try to get cached result
             # [Existing code continues]

         # On cache miss, compute the result
         computed_result = func(*args, **kwargs)
         # [Logging and other logic]

         # Save the computed result to cache if applicable
         if cache_none or computed_result is not None:
             cache_entry = FunctionCache(
                 function_name=function_name,
                 full_function_name=full_function_name,
                 args_hash=args_hash,
                 args=args_dict,
                 result=computed_result,
                 created_at=utcnow(),
             )
-            with DBManager(api_keys).get_session() as session:
+            with db_manager.get_session() as session:
                 logger.info(f"Saving {cache_entry} into database.")
                 session.add(cache_entry)
                 session.commit()

         return computed_result

     return cast(FunctionT, wrapper)
```

Also applies to: 146-146, 199-199
142-144: Simplify the check for Pydantic models in the return type

The condition to check if the return type contains a Pydantic model can be simplified for readability. Since contains_pydantic_model already checks for None, the return_type is not None check is redundant.

Apply this diff to simplify the condition:

```diff
-    is_pydantic_model = return_type is not None and contains_pydantic_model(
-        return_type
-    )
+    is_pydantic_model = contains_pydantic_model(return_type)
```
146-146: Add exception handling for database session operations

While using the with statement ensures that sessions are properly closed, it's prudent to add exception handling around database operations to catch and handle potential exceptions such as connection errors or data integrity issues. This enhances the robustness of the application.

Consider wrapping database interactions in try-except blocks:

```diff
 with db_manager.get_session() as session:
+    try:
         # Database operations
+    except Exception as e:
+        logger.error(f"Database operation failed: {e}")
+        raise
```

Also applies to: 199-199
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
⛔ Files ignored due to path filters (2)
- poetry.lock is excluded by !**/*.lock
- pyproject.toml is excluded by !**/*.toml
📒 Files selected for processing (10)
- prediction_market_agent_tooling/config.py (1 hunks)
- prediction_market_agent_tooling/markets/omen/data_models.py (1 hunks)
- prediction_market_agent_tooling/tools/caches/db_cache.py (5 hunks)
- prediction_market_agent_tooling/tools/caches/serializers.py (1 hunks)
- prediction_market_agent_tooling/tools/db/db_manager.py (1 hunks)
- prediction_market_agent_tooling/tools/relevant_news_analysis/relevant_news_cache.py (3 hunks)
- tests/conftest.py (1 hunks)
- tests/tools/db/test_db_manager.py (1 hunks)
- tests/tools/test_db_cache.py (15 hunks)
- tests_integration/tools/test_relevant_news_analysis.py (2 hunks)
🔇 Additional comments (18)
tests/conftest.py (3)
5-6: LGTM! Well-organized imports
The new imports for PostgreSQL testing utilities follow proper import ordering conventions and are appropriate for the implemented functionality.
15-21: LGTM! Proper database lifecycle management
The use of DatabaseJanitor ensures proper cleanup of database resources during testing. The parameters are correctly derived from postgresql_proc.
11-13: Verify fixture updates across test files

This is a breaking change, as the fixture name has changed from keys_with_sqlalchemy_db_url to session_keys_with_postgresql_proc_and_enabled_cache. Let's verify that all test files have been updated accordingly.
✅ Verification successful
Fixture update verification completed successfully
The search results show that:
- There are no remaining references to the old fixture name keys_with_sqlalchemy_db_url in the tests directory
- The new fixture name session_keys_with_postgresql_proc_and_enabled_cache is properly used across test files, particularly in tests/tools/test_db_cache.py
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:

```bash
#!/bin/bash
# Search for any remaining usage of the old fixture name
echo "Checking for old fixture usage..."
rg "keys_with_sqlalchemy_db_url" tests/

# Search for usage of the new fixture name to ensure it's being used
echo "Checking new fixture usage..."
rg "session_keys_with_postgresql_proc_and_enabled_cache" tests/
```
Length of output: 3539
tests/tools/db/test_db_manager.py (1)
1-6: LGTM! Well-structured imports.
The imports are appropriate and follow security best practices by using SecretStr for database URLs.
prediction_market_agent_tooling/tools/caches/serializers.py (2)
1-8: LGTM! Well-organized imports.
The imports follow good practices with standard library imports first, followed by third-party and local imports. All imports are utilized in the code.
1-61: Verify bidirectional serialization for all supported types.
Let's ensure that all supported types can be correctly serialized and deserialized back to their original form.
✅ Verification successful
Serialization system is correctly implemented and used
The verification shows that the serialization system properly handles all supported types:

- DatetimeUTC: correctly serialized with ISO format and a prefix
- timedelta: properly handled using total seconds
- date: serialized using ISO format
- BaseModel: serialized using model_dump()

The implementation is used consistently throughout the codebase, primarily in:

- Database management (db_manager.py)
- Cache implementations
- Market data models
- API responses
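For illustration, a round trip along the following lines should hold, given the serializers described above; the exact prefixed string forms in the comments are paraphrased from the deserializer code and are not asserted here:

```python
from datetime import date, timedelta

from prediction_market_agent_tooling.tools.caches.serializers import (
    json_deserializer,
    json_serializer,
)

payload = {
    "when": date(2024, 1, 1),    # encoded roughly as "date::2024-01-01"
    "wait": timedelta(hours=1),  # encoded roughly as "timedelta::3600.0"
    "note": "plain strings pass through unchanged",
}

encoded = json_serializer(payload)    # a plain JSON string with prefixed markers
decoded = json_deserializer(encoded)  # prefixes are parsed back into objects

assert decoded["when"] == payload["when"]
assert decoded["wait"] == payload["wait"]
assert decoded["note"] == payload["note"]
```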
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:

```bash
#!/bin/bash
# Description: Find all usages of the serialization functions to verify type coverage

# Look for direct usage of serialization functions
echo "Direct usage of serializers:"
rg -l "json_serializer|json_deserializer" --type python

# Look for the custom types being serialized
echo -e "\nCustom type usage:"
rg -l "DatetimeUTC|timedelta|BaseModel" --type python

# Look for potential new types that might need serialization
echo -e "\nPotential new types to consider:"
ast-grep --pattern 'class $name(BaseModel)' --type python
```
Length of output: 616
Script:

```bash
#!/bin/bash
# Description: Find all usages of the serialization functions to verify type coverage

# Look for direct usage of serialization functions
echo "Direct usage of serializers:"
rg "json_serializer|json_deserializer"

# Look for the custom types being serialized
echo -e "\nCustom type usage:"
rg "DatetimeUTC|timedelta|BaseModel"

# Look for potential new types that might need serialization
echo -e "\nPotential new types to consider:"
ast-grep 'class $name(BaseModel)'
```
Length of output: 50975
tests_integration/tools/test_relevant_news_analysis.py (2)
5-5: LGTM! Good security practice with SecretStr.

The addition of SecretStr from pydantic and APIKeys improves the handling of sensitive configuration data.

Also applies to: 7-7

67-69: Verify dependencies before closing the PR.
Since this PR is marked for closure, we should verify that PR #547 fully addresses the database management needs and that there are no dependent changes waiting on this implementation.
prediction_market_agent_tooling/config.py (2)
189-189: LGTM: Good practice using instance attribute

The change to use self.model_fields instead of the class attribute is a better practice, as it makes the code more maintainable and allows for proper inheritance.

196-196: LGTM: Consistent with previous change

The change maintains consistency with model_dump_public and properly handles secret field access.
tests/tools/test_db_cache.py (2)
Line range hint 85-104: LGTM! Good use of immutable configuration update

The test correctly uses model_copy to create a new instance with disabled cache, maintaining immutability. The assertion message clearly explains the test's purpose.

11-11: Verify the new test fixture implementation

The consistent change from keys_with_sqlalchemy_db_url to session_keys_with_postgresql_proc_and_enabled_cache across all test functions suggests a significant change in the database handling approach. This change appears to move from a plain SQLAlchemy URL to PostgreSQL with process-based management.
Let's verify the fixture implementation:
Also applies to: 30-30, 47-47, 64-64, 85-85, 109-109, 126-126, 144-144, 162-162, 187-187, 210-210, 233-233, 258-258, 305-305
✅ Verification successful
Fixture implementation is properly managed and test isolation is maintained
The fixture session_keys_with_postgresql_proc_and_enabled_cache is well-implemented in conftest.py using pytest-postgresql. It:

- Uses PostgreSQLExecutor for process management
- Employs DatabaseJanitor for clean setup/teardown
- Properly configures the PostgreSQL connection with user, host, port, and database name
- Ensures test isolation through process-based PostgreSQL instances
The change from SQLAlchemy URL to PostgreSQL process is a valid improvement that maintains test isolation and provides better control over the database lifecycle.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:

```bash
#!/bin/bash
# Description: Check the fixture implementation and its usage
# Look for the fixture definition and any related configuration

# Search for fixture definition
echo "Searching for fixture definition..."
rg -A 10 "session_keys_with_postgresql_proc_and_enabled_cache"

# Search for PostgreSQL process handling
echo "Searching for PostgreSQL process handling..."
rg "PostgreSQLExecutor|postgresql_proc"
```
Length of output: 19913
prediction_market_agent_tooling/tools/relevant_news_analysis/relevant_news_cache.py (6)
4-4: Imports are appropriate and utilized correctly.

The import statement includes all necessary modules and they are used in the code.

8-8: Proper import of DBManager for centralized database management.

Importing DBManager enhances modularity by centralizing database operations.

27-28: Constructor updated to use APIKeys and initialize DBManager.

The __init__ method now accepts an APIKeys instance and initializes DBManager, improving configuration management.

35-35: Delegating table creation to DBManager enhances maintainability.

Using self.db_manager.create_tables([RelevantNewsCacheModel]) centralizes table management.

42-42: Consistent session management using DBManager.

Replacing direct session creation with self.db_manager.get_session() ensures consistent and reliable session handling.

75-75: Aligned session usage in save method with updated pattern.

The save method now utilizes self.db_manager.get_session(), maintaining consistency across database operations.
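Taken together, these comments describe roughly the following shape for the updated cache class. This is a paraphrased sketch, not the exact file contents: the save body details are assumptions, and RelevantNewsCacheModel is the SQLModel table defined in the same module.

```python
from prediction_market_agent_tooling.config import APIKeys
from prediction_market_agent_tooling.tools.db.db_manager import DBManager


class RelevantNewsResponseCache:
    def __init__(self, api_keys: APIKeys | None = None) -> None:
        # DBManager centralizes engine/session handling; APIKeys carries SQLALCHEMY_DB_URL.
        self.db_manager = DBManager(api_keys if api_keys is not None else APIKeys())
        # Table creation is delegated to DBManager, which skips tables it already created.
        self.db_manager.create_tables([RelevantNewsCacheModel])

    def save(self, item: "RelevantNewsCacheModel") -> None:
        # Sessions come from the shared DBManager instead of being created ad hoc.
        with self.db_manager.get_session() as session:
            session.add(item)
            session.commit()
```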
```python
def json_deserializer(s: str) -> t.Any:
    data = json.loads(s)
    return replace_custom_stringified_objects(data)
```
🛠️ Refactor suggestion
Add error handling and documentation.
The function should handle JSON decode errors and include documentation explaining its purpose and expected input/output.
```diff
 def json_deserializer(s: str) -> t.Any:
+    """
+    Deserialize a JSON string into Python objects, handling custom type deserialization.
+
+    Args:
+        s: JSON string to deserialize
+
+    Returns:
+        Deserialized Python object
+
+    Raises:
+        json.JSONDecodeError: If the input string is not valid JSON
+    """
-    data = json.loads(s)
-    return replace_custom_stringified_objects(data)
+    try:
+        data = json.loads(s)
+        return replace_custom_stringified_objects(data)
+    except json.JSONDecodeError as e:
+        raise json.JSONDecodeError(
+            f"Failed to decode JSON string: {e.msg}",
+            e.doc,
+            e.pos,
+        ) from e
```
```python
def replace_custom_stringified_objects(obj: t.Any) -> t.Any:
    """
    Used to deserialize objects from `json_serializer_default_fn` into their proper form.
    """
    if isinstance(obj, str):
        if obj.startswith("DatetimeUTC::"):
            iso_str = obj[len("DatetimeUTC::") :]
            return DatetimeUTC.to_datetime_utc(iso_str)
        elif obj.startswith("timedelta::"):
            total_seconds_str = obj[len("timedelta::") :]
            return timedelta(seconds=float(total_seconds_str))
        elif obj.startswith("date::"):
            iso_str = obj[len("date::") :]
            return date.fromisoformat(iso_str)
        else:
            return obj
    elif isinstance(obj, dict):
        return {k: replace_custom_stringified_objects(v) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [replace_custom_stringified_objects(item) for item in obj]
    else:
        return obj
```
🛠️ Refactor suggestion
Enhance robustness and maintainability of custom type handling.
The current implementation using string prefixes could be improved by:
- Using constants for type prefixes
- Adding format validation
- Making the type handling more maintainable
```diff
+# At the top of the file
+class SerializerPrefix:
+    DATETIME_UTC = "DatetimeUTC::"
+    TIMEDELTA = "timedelta::"
+    DATE = "date::"
+
+
+def _parse_prefixed_value(value: str, prefix: str, parser: t.Callable[[str], t.Any]) -> tuple[bool, t.Any]:
+    """Helper to safely parse prefixed values"""
+    if value.startswith(prefix):
+        try:
+            return True, parser(value[len(prefix):])
+        except ValueError as e:
+            raise ValueError(f"Invalid format for {prefix.rstrip(':')}: {e}") from e
+    return False, value
+
+
 def replace_custom_stringified_objects(obj: t.Any) -> t.Any:
     """
     Used to deserialize objects from `json_serializer_default_fn` into their proper form.
     """
     if isinstance(obj, str):
-        if obj.startswith("DatetimeUTC::"):
-            iso_str = obj[len("DatetimeUTC::") :]
-            return DatetimeUTC.to_datetime_utc(iso_str)
-        elif obj.startswith("timedelta::"):
-            total_seconds_str = obj[len("timedelta::") :]
-            return timedelta(seconds=float(total_seconds_str))
-        elif obj.startswith("date::"):
-            iso_str = obj[len("date::") :]
-            return date.fromisoformat(iso_str)
-        else:
-            return obj
+        # Try each type parser in sequence
+        matched, result = _parse_prefixed_value(
+            obj, SerializerPrefix.DATETIME_UTC, DatetimeUTC.to_datetime_utc
+        )
+        if matched:
+            return result
+
+        matched, result = _parse_prefixed_value(
+            obj, SerializerPrefix.TIMEDELTA, lambda x: timedelta(seconds=float(x))
+        )
+        if matched:
+            return result
+
+        matched, result = _parse_prefixed_value(
+            obj, SerializerPrefix.DATE, date.fromisoformat
+        )
+        if matched:
+            return result
+
+        return obj
```
This refactoring:
- Centralizes prefix definitions
- Adds proper error handling for malformed values
- Makes adding new types easier
- Improves error messages
Consider using a registry pattern if more custom types need to be added in the future.
```python
    def from_tuple(values: tuple[t.Any, ...]) -> "ContractPrediction":
        return ContractPrediction(
            publisher=values[0],
            ipfs_hash=values[1],
            tx_hashes=values[2],
            estimated_probability_bps=values[3],
        )
```
🛠️ Refactor suggestion
Add validation and documentation for the tuple structure.
The from_tuple method could benefit from additional validation and documentation:
```diff
     @staticmethod
-    def from_tuple(values: tuple[t.Any, ...]) -> "ContractPrediction":
-        return ContractPrediction(
-            publisher=values[0],
-            ipfs_hash=values[1],
-            tx_hashes=values[2],
-            estimated_probability_bps=values[3],
-        )
+    def from_tuple(values: tuple[t.Any, ...]) -> "ContractPrediction":
+        """Create a ContractPrediction instance from a tuple.
+
+        Args:
+            values: A tuple containing (publisher: str, ipfs_hash: HexBytes,
+                tx_hashes: list[HexBytes], estimated_probability_bps: int)
+
+        Raises:
+            ValueError: If the tuple doesn't have exactly 4 elements
+            TypeError: If the elements don't match the expected types
+        """
+        if len(values) != 4:
+            raise ValueError(f"Expected 4 values, got {len(values)}")
+
+        publisher, ipfs_hash, tx_hashes, probability_bps = values
+
+        if not isinstance(probability_bps, int):
+            raise TypeError(f"estimated_probability_bps must be an integer, got {type(probability_bps)}")
+
+        return ContractPrediction(
+            publisher=publisher,
+            ipfs_hash=ipfs_hash,
+            tx_hashes=tx_hashes,
+            estimated_probability_bps=probability_bps,
+        )
```
The changes:
- Add docstring explaining the expected tuple structure
- Add tuple length validation
- Add type validation for critical fields
- Use tuple unpacking for better readability
```python
    def __new__(cls, api_keys: APIKeys | None = None) -> "DBManager":
        sqlalchemy_db_url = (api_keys or APIKeys()).sqlalchemy_db_url
        secret_value = sqlalchemy_db_url.get_secret_value()
        url_hash = hashlib.md5(secret_value.encode()).hexdigest()
        if url_hash not in cls._instances:
            instance = super(DBManager, cls).__new__(cls)
            cls._instances[url_hash] = instance
        return cls._instances[url_hash]
```
Potential Thread-Safety Issue in Singleton Implementation
The _instances dictionary used for implementing the singleton pattern is not thread-safe. In a multi-threaded environment, simultaneous access to _instances without synchronization could lead to race conditions, resulting in multiple instances of DBManager for the same database URL.
Consider adding a threading lock to synchronize access:
```diff
+import threading
+
 class DBManager:
+    _lock = threading.Lock()
     _instances: dict[str, "DBManager"] = {}

     def __new__(cls, api_keys: APIKeys | None = None) -> "DBManager":
         sqlalchemy_db_url = (api_keys or APIKeys()).sqlalchemy_db_url
         secret_value = sqlalchemy_db_url.get_secret_value()
         url_hash = hashlib.md5(secret_value.encode()).hexdigest()
-        if url_hash not in cls._instances:
-            instance = super(DBManager, cls).__new__(cls)
-            cls._instances[url_hash] = instance
-        return cls._instances[url_hash]
+        with cls._lock:
+            if url_hash not in cls._instances:
+                instance = super(DBManager, cls).__new__(cls)
+                cls._instances[url_hash] = instance
+            return cls._instances[url_hash]
```
```diff
         # Create table if it doesn't exist
-        SQLModel.metadata.create_all(engine)
+        DBManager(api_keys).create_tables([FunctionCache])
```
🛠️ Refactor suggestion
Ensure table creation occurs only once during application startup
Calling create_tables([FunctionCache]) inside the wrapper function results in attempting to create the table every time a cached function is called, which can be inefficient. It's better to initialize the necessary tables once during the application startup or module import.

Consider moving the table creation logic outside the wrapper function. For example:
```diff
 def db_cache(
     func: FunctionT | None = None,
     *,
     max_age: timedelta | None = None,
     cache_none: bool = True,
     api_keys: APIKeys | None = None,
     ignore_args: Sequence[str] | None = None,
     ignore_arg_types: Sequence[type] | None = None,
 ) -> FunctionT | Callable[[FunctionT], FunctionT]:
     if func is None:
         # [Existing code continues]

     api_keys = api_keys if api_keys is not None else APIKeys()
     db_manager = DBManager(api_keys)
+    # Ensure tables are created once
+    db_manager.create_tables([FunctionCache])

     @wraps(func)
     def wrapper(*args: Any, **kwargs: Any) -> Any:
         # [Existing code continues]
-        db_manager.create_tables([FunctionCache])
         # [Existing code continues]
```

If moving the table creation causes issues with tests or application startup, consider initializing the tables in a centralized location where the application context is first established.
Committable suggestion skipped: line range outside the PR's diff.