From 9559eb6ef13e0ad335ae60ec9e961b42800d5541 Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Mon, 30 Sep 2024 13:09:14 +0200 Subject: [PATCH] Add ruff rules UP(pyupgrade) --- .github/workflows/lint.yml | 4 +- Makefile | 4 +- astrapy/admin.py | 528 ++++++++-------- astrapy/api_commander.py | 90 ++- astrapy/api_options.py | 8 +- astrapy/authentication.py | 20 +- astrapy/client.py | 94 +-- astrapy/collection.py | 488 ++++++++------- astrapy/constants.py | 6 +- astrapy/core/api.py | 72 +-- astrapy/core/core_types.py | 4 +- astrapy/core/db.py | 577 +++++++++--------- astrapy/core/ops.py | 104 ++-- astrapy/core/utils.py | 80 ++- astrapy/cursors.py | 154 +++-- astrapy/database.py | 344 +++++------ astrapy/exceptions.py | 132 ++-- astrapy/info.py | 134 ++-- astrapy/meta.py | 14 +- astrapy/operations.py | 212 +++---- astrapy/request_tools.py | 10 +- astrapy/results.py | 18 +- astrapy/transform_payload.py | 26 +- astrapy/user_agents.py | 13 +- poetry.lock | 41 +- pyproject.toml | 8 +- scripts/astrapy_latest_interface.py | 3 +- tests/conftest.py | 22 +- tests/core/conftest.py | 14 +- tests/core/test_async_db_dml.py | 18 +- tests/core/test_async_db_dml_pagination.py | 3 +- tests/core/test_db_dml.py | 18 +- tests/core/test_db_dml_pagination.py | 3 +- tests/core/test_ops.py | 2 +- tests/idiomatic/integration/test_admin.py | 16 +- tests/idiomatic/integration/test_dml_async.py | 14 +- tests/idiomatic/integration/test_dml_sync.py | 8 +- .../integration/test_exceptions_async.py | 5 +- tests/idiomatic/unit/test_apicommander.py | 5 +- .../idiomatic/unit/test_collection_options.py | 4 +- .../unit/test_document_extractors.py | 4 +- tests/preprocess_env.py | 29 +- tests/vectorize_idiomatic/conftest.py | 6 +- .../test_vectorize_methods_async.py | 12 +- .../test_vectorize_methods_sync.py | 10 +- .../integration/test_vectorize_providers.py | 18 +- tests/vectorize_idiomatic/query_providers.py | 3 +- tests/vectorize_idiomatic/vectorize_models.py | 9 +- 48 files changed, 1679 insertions(+), 1732 deletions(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 9c32d4a7..d066199a 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -27,11 +27,11 @@ jobs: - name: Ruff Linting AstraPy run: | - poetry run ruff astrapy + poetry run ruff check astrapy - name: Ruff Linting Tests run: | - poetry run ruff tests + poetry run ruff check tests - name: Isort Linting AstraPy run: | diff --git a/Makefile b/Makefile index 147ee7f5..f58fd615 100644 --- a/Makefile +++ b/Makefile @@ -9,13 +9,13 @@ FMT_FLAGS ?= --check format: format-src format-tests format-tests: - poetry run ruff tests + poetry run ruff check tests poetry run isort tests $(FMT_FLAGS) --profile black poetry run black tests $(FMT_FLAGS) poetry run mypy tests format-src: - poetry run ruff astrapy + poetry run ruff check astrapy poetry run isort astrapy $(FMT_FLAGS) --profile black poetry run black astrapy $(FMT_FLAGS) poetry run mypy astrapy diff --git a/astrapy/admin.py b/astrapy/admin.py index 4438c3f2..3b80a7e2 100644 --- a/astrapy/admin.py +++ b/astrapy/admin.py @@ -21,7 +21,7 @@ import warnings from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any import deprecation @@ -107,7 +107,7 @@ class ParsedAPIEndpoint: environment: str -def parse_api_endpoint(api_endpoint: str) -> Optional[ParsedAPIEndpoint]: +def parse_api_endpoint(api_endpoint: str) -> ParsedAPIEndpoint | None: """ Parse an API Endpoint into a ParsedAPIEndpoint structure. @@ -140,7 +140,7 @@ def api_endpoint_parsing_error_message(failing_url: str) -> str: ) -def parse_generic_api_url(api_endpoint: str) -> Optional[str]: +def parse_generic_api_url(api_endpoint: str) -> str | None: """ Validate a generic API Endpoint string, such as `http://10.1.1.1:123` or `https://my.domain`. @@ -197,10 +197,10 @@ def build_api_endpoint(environment: str, database_id: str, region: str) -> str: def fetch_raw_database_info_from_id_token( id: str, *, - token: Optional[str], + token: str | None, environment: str = Environment.PROD, - max_time_ms: Optional[int] = None, -) -> Dict[str, Any]: + max_time_ms: int | None = None, +) -> dict[str, Any]: """ Fetch database information through the DevOps API and return it in full, exactly like the API gives it back. @@ -217,7 +217,7 @@ def fetch_raw_database_info_from_id_token( The full response from the DevOps API about the database. """ - ops_headers: Dict[str, str | None] + ops_headers: dict[str, str | None] if token is not None: ops_headers = { DEFAULT_DEV_OPS_AUTH_HEADER: f"{DEFAULT_DEV_OPS_AUTH_PREFIX}{token}", @@ -248,10 +248,10 @@ def fetch_raw_database_info_from_id_token( async def async_fetch_raw_database_info_from_id_token( id: str, *, - token: Optional[str], + token: str | None, environment: str = Environment.PROD, - max_time_ms: Optional[int] = None, -) -> Dict[str, Any]: + max_time_ms: int | None = None, +) -> dict[str, Any]: """ Fetch database information through the DevOps API and return it in full, exactly like the API gives it back. @@ -269,7 +269,7 @@ async def async_fetch_raw_database_info_from_id_token( The full response from the DevOps API about the database. """ - ops_headers: Dict[str, str | None] + ops_headers: dict[str, str | None] if token is not None: ops_headers = { DEFAULT_DEV_OPS_AUTH_HEADER: f"{DEFAULT_DEV_OPS_AUTH_PREFIX}{token}", @@ -299,11 +299,11 @@ async def async_fetch_raw_database_info_from_id_token( def fetch_database_info( api_endpoint: str, - token: Optional[str], - keyspace: Optional[str] = None, - namespace: Optional[str] = None, - max_time_ms: Optional[int] = None, -) -> Optional[DatabaseInfo]: + token: str | None, + keyspace: str | None = None, + namespace: str | None = None, + max_time_ms: int | None = None, +) -> DatabaseInfo | None: """ Fetch database information through the DevOps API. @@ -353,11 +353,11 @@ def fetch_database_info( async def async_fetch_database_info( api_endpoint: str, - token: Optional[str], - keyspace: Optional[str] = None, - namespace: Optional[str] = None, - max_time_ms: Optional[int] = None, -) -> Optional[DatabaseInfo]: + token: str | None, + keyspace: str | None = None, + namespace: str | None = None, + max_time_ms: int | None = None, +) -> DatabaseInfo | None: """ Fetch database information through the DevOps API. Async version of the function, for use in an asyncio context. @@ -407,7 +407,7 @@ async def async_fetch_database_info( def _recast_as_admin_database_info( - admin_database_info_dict: Dict[str, Any], + admin_database_info_dict: dict[str, Any], *, environment: str, ) -> AdminDatabaseInfo: @@ -443,10 +443,10 @@ def _recast_as_admin_database_info( def normalize_api_endpoint( id_or_endpoint: str, - region: Optional[str], + region: str | None, token: TokenProvider, environment: str, - max_time_ms: Optional[int] = None, + max_time_ms: int | None = None, ) -> str: """ Ensure that a id(+region) / endpoint init signature is normalized into @@ -501,7 +501,7 @@ def normalize_api_endpoint( def normalize_id_endpoint_parameters( - id: Optional[str], api_endpoint: Optional[str] + id: str | None, api_endpoint: str | None ) -> str: if id is None: if api_endpoint is None: @@ -558,13 +558,13 @@ class AstraDBAdmin: def __init__( self, - token: Optional[Union[str, TokenProvider]] = None, + token: str | TokenProvider | None = None, *, - environment: Optional[str] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, - dev_ops_url: Optional[str] = None, - dev_ops_api_version: Optional[str] = None, + environment: str | None = None, + caller_name: str | None = None, + caller_version: str | None = None, + dev_ops_url: str | None = None, + dev_ops_api_version: str | None = None, ) -> None: self.token_provider = coerce_token_provider(token) self.environment = (environment or Environment.PROD).lower() @@ -577,7 +577,7 @@ def __init__( self._dev_ops_url = dev_ops_url self._dev_ops_api_version = dev_ops_api_version - self._dev_ops_commander_headers: Dict[str, str | None] + self._dev_ops_commander_headers: dict[str, str | None] if self.token_provider: _token = self.token_provider.get_token() self._dev_ops_commander_headers = { @@ -591,12 +591,12 @@ def __init__( self._dev_ops_api_commander = self._get_dev_ops_api_commander() def __repr__(self) -> str: - token_desc: Optional[str] + token_desc: str | None if self.token_provider: token_desc = f'"{redact_secret(str(self.token_provider), 15)}"' else: token_desc = None - env_desc: Optional[str] + env_desc: str | None if self.environment == Environment.PROD: env_desc = None else: @@ -640,12 +640,12 @@ def _get_dev_ops_api_commander(self) -> APICommander: def _copy( self, *, - token: Optional[Union[str, TokenProvider]] = None, - environment: Optional[str] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, - dev_ops_url: Optional[str] = None, - dev_ops_api_version: Optional[str] = None, + token: str | TokenProvider | None = None, + environment: str | None = None, + caller_name: str | None = None, + caller_version: str | None = None, + dev_ops_url: str | None = None, + dev_ops_api_version: str | None = None, ) -> AstraDBAdmin: return AstraDBAdmin( token=coerce_token_provider(token) or self.token_provider, @@ -659,9 +659,9 @@ def _copy( def with_options( self, *, - token: Optional[Union[str, TokenProvider]] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + token: str | TokenProvider | None = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> AstraDBAdmin: """ Create a clone of this AstraDBAdmin with some changed attributes. @@ -693,8 +693,8 @@ def with_options( def set_caller( self, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> None: """ Set a new identity for the application/framework on behalf of which @@ -722,7 +722,7 @@ def set_caller( def list_databases( self, *, - max_time_ms: Optional[int] = None, + max_time_ms: int | None = None, ) -> CommandCursor[AdminDatabaseInfo]: """ Get the list of databases, as obtained with a request to the DevOps API. @@ -772,7 +772,7 @@ def list_databases( async def async_list_databases( self, *, - max_time_ms: Optional[int] = None, + max_time_ms: int | None = None, ) -> CommandCursor[AdminDatabaseInfo]: """ Get the list of databases, as obtained with a request to the DevOps API. @@ -822,7 +822,7 @@ async def async_list_databases( ) def database_info( - self, id: str, *, max_time_ms: Optional[int] = None + self, id: str, *, max_time_ms: int | None = None ) -> AdminDatabaseInfo: """ Get the full information on a given database, through a request to the DevOps API. @@ -858,7 +858,7 @@ def database_info( ) async def async_database_info( - self, id: str, *, max_time_ms: Optional[int] = None + self, id: str, *, max_time_ms: int | None = None ) -> AdminDatabaseInfo: """ Get the full information on a given database, through a request to the DevOps API. @@ -899,10 +899,10 @@ def create_database( *, cloud_provider: str, region: str, - keyspace: Optional[str] = None, - namespace: Optional[str] = None, + keyspace: str | None = None, + namespace: str | None = None, wait_until_active: bool = True, - max_time_ms: Optional[int] = None, + max_time_ms: int | None = None, ) -> AstraDBDatabaseAdmin: """ Create a database as requested, optionally waiting for it to be ready. @@ -1012,10 +1012,10 @@ async def async_create_database( *, cloud_provider: str, region: str, - keyspace: Optional[str] = None, - namespace: Optional[str] = None, + keyspace: str | None = None, + namespace: str | None = None, wait_until_active: bool = True, - max_time_ms: Optional[int] = None, + max_time_ms: int | None = None, ) -> AstraDBDatabaseAdmin: """ Create a database as requested, optionally waiting for it to be ready. @@ -1129,8 +1129,8 @@ def drop_database( id: str, *, wait_until_active: bool = True, - max_time_ms: Optional[int] = None, - ) -> Dict[str, Any]: + max_time_ms: int | None = None, + ) -> dict[str, Any]: """ Drop a database, i.e. delete it completely and permanently with all its data. @@ -1180,8 +1180,8 @@ def drop_database( ) logger.info(f"DevOps API returned from dropping database '{id}'") if wait_until_active: - last_status_seen: Optional[str] = DEV_OPS_DATABASE_STATUS_TERMINATING - _db_name: Optional[str] = None + last_status_seen: str | None = DEV_OPS_DATABASE_STATUS_TERMINATING + _db_name: str | None = None while last_status_seen == DEV_OPS_DATABASE_STATUS_TERMINATING: logger.info(f"sleeping to poll for status of '{id}'") time.sleep(DEV_OPS_DATABASE_POLL_INTERVAL_S) @@ -1212,8 +1212,8 @@ async def async_drop_database( id: str, *, wait_until_active: bool = True, - max_time_ms: Optional[int] = None, - ) -> Dict[str, Any]: + max_time_ms: int | None = None, + ) -> dict[str, Any]: """ Drop a database, i.e. delete it completely and permanently with all its data. Async version of the method, for use in an asyncio context. @@ -1260,8 +1260,8 @@ async def async_drop_database( ) logger.info(f"DevOps API returned from dropping database '{id}', async") if wait_until_active: - last_status_seen: Optional[str] = DEV_OPS_DATABASE_STATUS_TERMINATING - _db_name: Optional[str] = None + last_status_seen: str | None = DEV_OPS_DATABASE_STATUS_TERMINATING + _db_name: str | None = None while last_status_seen == DEV_OPS_DATABASE_STATUS_TERMINATING: logger.info(f"sleeping to poll for status of '{id}', async") await asyncio.sleep(DEV_OPS_DATABASE_POLL_INTERVAL_S) @@ -1289,11 +1289,11 @@ async def async_drop_database( def get_database_admin( self, - id: Optional[str] = None, + id: str | None = None, *, - api_endpoint: Optional[str] = None, - region: Optional[str] = None, - max_time_ms: Optional[int] = None, + api_endpoint: str | None = None, + region: str | None = None, + max_time_ms: int | None = None, ) -> AstraDBDatabaseAdmin: """ Create an AstraDBDatabaseAdmin object for admin work within a certain database. @@ -1342,16 +1342,16 @@ def get_database_admin( def get_database( self, - id: Optional[str] = None, + id: str | None = None, *, - api_endpoint: Optional[str] = None, - token: Optional[Union[str, TokenProvider]] = None, - keyspace: Optional[str] = None, - namespace: Optional[str] = None, - region: Optional[str] = None, - api_path: Optional[str] = None, - api_version: Optional[str] = None, - max_time_ms: Optional[int] = None, + api_endpoint: str | None = None, + token: str | TokenProvider | None = None, + keyspace: str | None = None, + namespace: str | None = None, + region: str | None = None, + api_path: str | None = None, + api_version: str | None = None, + max_time_ms: int | None = None, ) -> Database: """ Create a Database instance for a specific database, to be used @@ -1422,7 +1422,7 @@ def get_database( max_time_ms=max_time_ms, ) - _keyspace: Optional[str] + _keyspace: str | None if keyspace_param: _keyspace = keyspace_param else: @@ -1450,15 +1450,15 @@ def get_database( def get_async_database( self, - id: Optional[str] = None, + id: str | None = None, *, - api_endpoint: Optional[str] = None, - token: Optional[Union[str, TokenProvider]] = None, - keyspace: Optional[str] = None, - namespace: Optional[str] = None, - region: Optional[str] = None, - api_path: Optional[str] = None, - api_version: Optional[str] = None, + api_endpoint: str | None = None, + token: str | TokenProvider | None = None, + keyspace: str | None = None, + namespace: str | None = None, + region: str | None = None, + api_path: str | None = None, + api_version: str | None = None, ) -> AsyncDatabase: """ Create an AsyncDatabase instance for a specific database, to be used @@ -1492,15 +1492,15 @@ class DatabaseAdmin(ABC): """ environment: str - spawner_database: Union[Database, AsyncDatabase] + spawner_database: Database | AsyncDatabase @abstractmethod - def list_namespaces(self, *pargs: Any, **kwargs: Any) -> List[str]: + def list_namespaces(self, *pargs: Any, **kwargs: Any) -> list[str]: """Get a list of namespaces for the database.""" ... @abstractmethod - def list_keyspaces(self, *pargs: Any, **kwargs: Any) -> List[str]: + def list_keyspaces(self, *pargs: Any, **kwargs: Any) -> list[str]: """Get a list of keyspaces for the database.""" ... @@ -1509,10 +1509,10 @@ def create_namespace( self, name: str, *, - update_db_keyspace: Optional[bool] = None, - update_db_namespace: Optional[bool] = None, + update_db_keyspace: bool | None = None, + update_db_namespace: bool | None = None, **kwargs: Any, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ Create a namespace in the database, returning {'ok': 1} if successful. """ @@ -1523,31 +1523,31 @@ def create_keyspace( self, name: str, *, - update_db_keyspace: Optional[bool] = None, - update_db_namespace: Optional[bool] = None, + update_db_keyspace: bool | None = None, + update_db_namespace: bool | None = None, **kwargs: Any, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ Create a keyspace in the database, returning {'ok': 1} if successful. """ ... @abstractmethod - def drop_namespace(self, name: str, *pargs: Any, **kwargs: Any) -> Dict[str, Any]: + def drop_namespace(self, name: str, *pargs: Any, **kwargs: Any) -> dict[str, Any]: """ Drop (delete) a namespace from the database, returning {'ok': 1} if successful. """ ... @abstractmethod - def drop_keyspace(self, name: str, *pargs: Any, **kwargs: Any) -> Dict[str, Any]: + def drop_keyspace(self, name: str, *pargs: Any, **kwargs: Any) -> dict[str, Any]: """ Drop (delete) a keyspace from the database, returning {'ok': 1} if successful. """ ... @abstractmethod - async def async_list_namespaces(self, *pargs: Any, **kwargs: Any) -> List[str]: + async def async_list_namespaces(self, *pargs: Any, **kwargs: Any) -> list[str]: """ Get a list of namespaces for the database. (Async version of the method.) @@ -1555,7 +1555,7 @@ async def async_list_namespaces(self, *pargs: Any, **kwargs: Any) -> List[str]: ... @abstractmethod - async def async_list_keyspaces(self, *pargs: Any, **kwargs: Any) -> List[str]: + async def async_list_keyspaces(self, *pargs: Any, **kwargs: Any) -> list[str]: """ Get a list of keyspaces for the database. (Async version of the method.) @@ -1567,10 +1567,10 @@ async def async_create_namespace( self, name: str, *, - update_db_keyspace: Optional[bool] = None, - update_db_namespace: Optional[bool] = None, + update_db_keyspace: bool | None = None, + update_db_namespace: bool | None = None, **kwargs: Any, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ Create a namespace in the database, returning {'ok': 1} if successful. (Async version of the method.) @@ -1582,10 +1582,10 @@ async def async_create_keyspace( self, name: str, *, - update_db_keyspace: Optional[bool] = None, - update_db_namespace: Optional[bool] = None, + update_db_keyspace: bool | None = None, + update_db_namespace: bool | None = None, **kwargs: Any, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ Create a keyspace in the database, returning {'ok': 1} if successful. (Async version of the method.) @@ -1595,7 +1595,7 @@ async def async_create_keyspace( @abstractmethod async def async_drop_namespace( self, name: str, *pargs: Any, **kwargs: Any - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ Drop (delete) a namespace from the database, returning {'ok': 1} if successful. (Async version of the method.) @@ -1605,7 +1605,7 @@ async def async_drop_namespace( @abstractmethod async def async_drop_keyspace( self, name: str, *pargs: Any, **kwargs: Any - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ Drop (delete) a keyspace from the database, returning {'ok': 1} if successful. (Async version of the method.) @@ -1707,20 +1707,20 @@ class is created by a method such as `Database.get_database_admin()`, def __init__( self, - id: Optional[str] = None, + id: str | None = None, *, - api_endpoint: Optional[str] = None, - token: Optional[Union[str, TokenProvider]] = None, - region: Optional[str] = None, - environment: Optional[str] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, - dev_ops_url: Optional[str] = None, - dev_ops_api_version: Optional[str] = None, - api_path: Optional[str] = None, - api_version: Optional[str] = None, - spawner_database: Optional[Union[Database, AsyncDatabase]] = None, - max_time_ms: Optional[int] = None, + api_endpoint: str | None = None, + token: str | TokenProvider | None = None, + region: str | None = None, + environment: str | None = None, + caller_name: str | None = None, + caller_version: str | None = None, + dev_ops_url: str | None = None, + dev_ops_api_version: str | None = None, + api_path: str | None = None, + api_version: str | None = None, + spawner_database: Database | AsyncDatabase | None = None, + max_time_ms: int | None = None, ) -> None: # lazy import here to avoid circular dependency from astrapy.database import Database @@ -1793,7 +1793,7 @@ def __init__( if dev_ops_api_version is not None else DEV_OPS_VERSION_ENV_MAP[self.environment] ).strip("/") - self._dev_ops_commander_headers: Dict[str, str | None] + self._dev_ops_commander_headers: dict[str, str | None] if self.token_provider: _token = self.token_provider.get_token() self._dev_ops_commander_headers = { @@ -1815,12 +1815,12 @@ def __init__( def __repr__(self) -> str: ep_desc = f'api_endpoint="{self.api_endpoint}"' - token_desc: Optional[str] + token_desc: str | None if self.token_provider: token_desc = f'token="{redact_secret(str(self.token_provider), 15)}"' else: token_desc = None - env_desc: Optional[str] + env_desc: str | None if self.environment == Environment.PROD: env_desc = None else: @@ -1879,16 +1879,16 @@ def _get_dev_ops_api_commander(self) -> APICommander: def _copy( self, - id: Optional[str] = None, - token: Optional[Union[str, TokenProvider]] = None, - region: Optional[str] = None, - environment: Optional[str] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, - dev_ops_url: Optional[str] = None, - dev_ops_api_version: Optional[str] = None, - api_path: Optional[str] = None, - api_version: Optional[str] = None, + id: str | None = None, + token: str | TokenProvider | None = None, + region: str | None = None, + environment: str | None = None, + caller_name: str | None = None, + caller_version: str | None = None, + dev_ops_url: str | None = None, + dev_ops_api_version: str | None = None, + api_path: str | None = None, + api_version: str | None = None, ) -> AstraDBDatabaseAdmin: return AstraDBDatabaseAdmin( id=id or self._database_id, @@ -1906,10 +1906,10 @@ def _copy( def with_options( self, *, - id: Optional[str] = None, - token: Optional[Union[str, TokenProvider]] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + id: str | None = None, + token: str | TokenProvider | None = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> AstraDBDatabaseAdmin: """ Create a clone of this AstraDBDatabaseAdmin with some changed attributes. @@ -1942,8 +1942,8 @@ def with_options( def set_caller( self, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> None: """ Set a new identity for the application/framework on behalf of which @@ -1995,9 +1995,9 @@ def region(self) -> str: def from_astra_db_admin( id: str, *, - region: Optional[str], + region: str | None, astra_db_admin: AstraDBAdmin, - max_time_ms: Optional[int] = None, + max_time_ms: int | None = None, ) -> AstraDBDatabaseAdmin: """ Create an AstraDBDatabaseAdmin from an AstraDBAdmin and a database ID. @@ -2053,11 +2053,11 @@ def from_astra_db_admin( def from_api_endpoint( api_endpoint: str, *, - token: Optional[Union[str, TokenProvider]] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, - dev_ops_url: Optional[str] = None, - dev_ops_api_version: Optional[str] = None, + token: str | TokenProvider | None = None, + caller_name: str | None = None, + caller_version: str | None = None, + dev_ops_url: str | None = None, + dev_ops_api_version: str | None = None, ) -> AstraDBDatabaseAdmin: """ Create an AstraDBDatabaseAdmin from an API Endpoint and optionally a token. @@ -2113,7 +2113,7 @@ def from_api_endpoint( msg = api_endpoint_parsing_error_message(api_endpoint) raise ValueError(msg) - def info(self, *, max_time_ms: Optional[int] = None) -> AdminDatabaseInfo: + def info(self, *, max_time_ms: int | None = None) -> AdminDatabaseInfo: """ Query the DevOps API for the full info on this database. @@ -2140,7 +2140,7 @@ def info(self, *, max_time_ms: Optional[int] = None) -> AdminDatabaseInfo: return req_response async def async_info( - self, *, max_time_ms: Optional[int] = None + self, *, max_time_ms: int | None = None ) -> AdminDatabaseInfo: """ Query the DevOps API for the full info on this database. @@ -2176,7 +2176,7 @@ async def async_info( current_version=__version__, details=NAMESPACE_DEPRECATION_NOTICE_METHOD, ) - def list_namespaces(self, *, max_time_ms: Optional[int] = None) -> List[str]: + def list_namespaces(self, *, max_time_ms: int | None = None) -> list[str]: """ Query the DevOps API for a list of the namespaces in the database. @@ -2195,7 +2195,7 @@ def list_namespaces(self, *, max_time_ms: Optional[int] = None) -> List[str]: return self.list_keyspaces(max_time_ms=max_time_ms) - def list_keyspaces(self, *, max_time_ms: Optional[int] = None) -> List[str]: + def list_keyspaces(self, *, max_time_ms: int | None = None) -> list[str]: """ Query the DevOps API for a list of the keyspaces in the database. @@ -2225,8 +2225,8 @@ def list_keyspaces(self, *, max_time_ms: Optional[int] = None) -> List[str]: details=NAMESPACE_DEPRECATION_NOTICE_METHOD, ) async def async_list_namespaces( - self, *, max_time_ms: Optional[int] = None - ) -> List[str]: + self, *, max_time_ms: int | None = None + ) -> list[str]: """ Query the DevOps API for a list of the namespaces in the database. Async version of the method, for use in an asyncio context. @@ -2255,8 +2255,8 @@ async def async_list_namespaces( return await self.async_list_keyspaces(max_time_ms=max_time_ms) async def async_list_keyspaces( - self, *, max_time_ms: Optional[int] = None - ) -> List[str]: + self, *, max_time_ms: int | None = None + ) -> list[str]: """ Query the DevOps API for a list of the keyspaces in the database. Async version of the method, for use in an asyncio context. @@ -2299,11 +2299,11 @@ def create_namespace( name: str, *, wait_until_active: bool = True, - update_db_keyspace: Optional[bool] = None, - update_db_namespace: Optional[bool] = None, - max_time_ms: Optional[int] = None, + update_db_keyspace: bool | None = None, + update_db_namespace: bool | None = None, + max_time_ms: int | None = None, **kwargs: Any, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ Create a namespace in this database as requested, optionally waiting for it to be ready. @@ -2357,11 +2357,11 @@ def create_keyspace( name: str, *, wait_until_active: bool = True, - update_db_keyspace: Optional[bool] = None, - update_db_namespace: Optional[bool] = None, - max_time_ms: Optional[int] = None, + update_db_keyspace: bool | None = None, + update_db_namespace: bool | None = None, + max_time_ms: int | None = None, **kwargs: Any, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ Create a keyspace in this database as requested, optionally waiting for it to be ready. @@ -2459,11 +2459,11 @@ async def async_create_namespace( name: str, *, wait_until_active: bool = True, - update_db_keyspace: Optional[bool] = None, - update_db_namespace: Optional[bool] = None, - max_time_ms: Optional[int] = None, + update_db_keyspace: bool | None = None, + update_db_namespace: bool | None = None, + max_time_ms: int | None = None, **kwargs: Any, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ Create a namespace in this database as requested, optionally waiting for it to be ready. @@ -2516,11 +2516,11 @@ async def async_create_keyspace( name: str, *, wait_until_active: bool = True, - update_db_keyspace: Optional[bool] = None, - update_db_namespace: Optional[bool] = None, - max_time_ms: Optional[int] = None, + update_db_keyspace: bool | None = None, + update_db_namespace: bool | None = None, + max_time_ms: int | None = None, **kwargs: Any, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ Create a keyspace in this database as requested, optionally waiting for it to be ready. @@ -2621,8 +2621,8 @@ def drop_namespace( name: str, *, wait_until_active: bool = True, - max_time_ms: Optional[int] = None, - ) -> Dict[str, Any]: + max_time_ms: int | None = None, + ) -> dict[str, Any]: """ Delete a namespace from the database, optionally waiting for the database to become active again. @@ -2667,8 +2667,8 @@ def drop_keyspace( name: str, *, wait_until_active: bool = True, - max_time_ms: Optional[int] = None, - ) -> Dict[str, Any]: + max_time_ms: int | None = None, + ) -> dict[str, Any]: """ Delete a keyspace from the database, optionally waiting for the database to become active again. @@ -2753,8 +2753,8 @@ async def async_drop_namespace( name: str, *, wait_until_active: bool = True, - max_time_ms: Optional[int] = None, - ) -> Dict[str, Any]: + max_time_ms: int | None = None, + ) -> dict[str, Any]: """ Delete a namespace from the database, optionally waiting for the database to become active again. @@ -2798,8 +2798,8 @@ async def async_drop_keyspace( name: str, *, wait_until_active: bool = True, - max_time_ms: Optional[int] = None, - ) -> Dict[str, Any]: + max_time_ms: int | None = None, + ) -> dict[str, Any]: """ Delete a keyspace from the database, optionally waiting for the database to become active again. @@ -2880,8 +2880,8 @@ def drop( self, *, wait_until_active: bool = True, - max_time_ms: Optional[int] = None, - ) -> Dict[str, Any]: + max_time_ms: int | None = None, + ) -> dict[str, Any]: """ Drop this database, i.e. delete it completely and permanently with all its data. @@ -2931,8 +2931,8 @@ async def async_drop( self, *, wait_until_active: bool = True, - max_time_ms: Optional[int] = None, - ) -> Dict[str, Any]: + max_time_ms: int | None = None, + ) -> dict[str, Any]: """ Drop this database, i.e. delete it completely and permanently with all its data. Async version of the method, for use in an asyncio context. @@ -2979,13 +2979,13 @@ async def async_drop( def get_database( self, *, - token: Optional[Union[str, TokenProvider]] = None, - keyspace: Optional[str] = None, - namespace: Optional[str] = None, - region: Optional[str] = None, - api_path: Optional[str] = None, - api_version: Optional[str] = None, - max_time_ms: Optional[int] = None, + token: str | TokenProvider | None = None, + keyspace: str | None = None, + namespace: str | None = None, + region: str | None = None, + api_path: str | None = None, + api_version: str | None = None, + max_time_ms: int | None = None, ) -> Database: """ Create a Database instance from this database admin, for data-related tasks. @@ -3049,13 +3049,13 @@ def get_database( def get_async_database( self, *, - token: Optional[Union[str, TokenProvider]] = None, - keyspace: Optional[str] = None, - namespace: Optional[str] = None, - region: Optional[str] = None, - api_path: Optional[str] = None, - api_version: Optional[str] = None, - max_time_ms: Optional[int] = None, + token: str | TokenProvider | None = None, + keyspace: str | None = None, + namespace: str | None = None, + region: str | None = None, + api_path: str | None = None, + api_version: str | None = None, + max_time_ms: int | None = None, ) -> AsyncDatabase: """ Create an AsyncDatabase instance out of this class for working @@ -3080,7 +3080,7 @@ def get_async_database( ).to_async() def find_embedding_providers( - self, *, max_time_ms: Optional[int] = None + self, *, max_time_ms: int | None = None ) -> FindEmbeddingProvidersResult: """ Query the API for the full information on available embedding providers. @@ -3123,7 +3123,7 @@ def find_embedding_providers( return FindEmbeddingProvidersResult.from_dict(fe_response["status"]) async def async_find_embedding_providers( - self, *, max_time_ms: Optional[int] = None + self, *, max_time_ms: int | None = None ) -> FindEmbeddingProvidersResult: """ Query the API for the full information on available embedding providers. @@ -3228,13 +3228,13 @@ def __init__( self, api_endpoint: str, *, - token: Optional[Union[str, TokenProvider]] = None, - environment: Optional[str] = None, - api_path: Optional[str] = None, - api_version: Optional[str] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, - spawner_database: Optional[Union[Database, AsyncDatabase]] = None, + token: str | TokenProvider | None = None, + environment: str | None = None, + api_path: str | None = None, + api_version: str | None = None, + caller_name: str | None = None, + caller_version: str | None = None, + spawner_database: Database | AsyncDatabase | None = None, ) -> None: # lazy import here to avoid circular dependency from astrapy.database import Database @@ -3269,7 +3269,7 @@ def __init__( def __repr__(self) -> str: ep_desc = f'api_endpoint="{self.api_endpoint}"' - token_desc: Optional[str] + token_desc: str | None if self.token_provider: token_desc = f'token="{redact_secret(str(self.token_provider), 15)}"' else: @@ -3301,13 +3301,13 @@ def _get_api_commander(self) -> APICommander: def _copy( self, - api_endpoint: Optional[str] = None, - token: Optional[Union[str, TokenProvider]] = None, - environment: Optional[str] = None, - api_path: Optional[str] = None, - api_version: Optional[str] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + api_endpoint: str | None = None, + token: str | TokenProvider | None = None, + environment: str | None = None, + api_path: str | None = None, + api_version: str | None = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> DataAPIDatabaseAdmin: return DataAPIDatabaseAdmin( api_endpoint=api_endpoint or self.api_endpoint, @@ -3322,10 +3322,10 @@ def _copy( def with_options( self, *, - api_endpoint: Optional[str] = None, - token: Optional[Union[str, TokenProvider]] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + api_endpoint: str | None = None, + token: str | TokenProvider | None = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> DataAPIDatabaseAdmin: """ Create a clone of this DataAPIDatabaseAdmin with some changed attributes. @@ -3358,8 +3358,8 @@ def with_options( def set_caller( self, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> None: """ Set a new identity for the application/framework on behalf of which @@ -3390,7 +3390,7 @@ def set_caller( current_version=__version__, details=NAMESPACE_DEPRECATION_NOTICE_METHOD, ) - def list_namespaces(self, *, max_time_ms: Optional[int] = None) -> List[str]: + def list_namespaces(self, *, max_time_ms: int | None = None) -> list[str]: """ Query the API for a list of the namespaces in the database. @@ -3420,7 +3420,7 @@ def list_namespaces(self, *, max_time_ms: Optional[int] = None) -> List[str]: logger.info("finished getting list of namespaces") return fn_response["status"]["namespaces"] # type: ignore[no-any-return] - def list_keyspaces(self, *, max_time_ms: Optional[int] = None) -> List[str]: + def list_keyspaces(self, *, max_time_ms: int | None = None) -> list[str]: """ Query the API for a list of the keyspaces in the database. @@ -3458,12 +3458,12 @@ def create_namespace( self, name: str, *, - replication_options: Optional[Dict[str, Any]] = None, - update_db_keyspace: Optional[bool] = None, - update_db_namespace: Optional[bool] = None, - max_time_ms: Optional[int] = None, + replication_options: dict[str, Any] | None = None, + update_db_keyspace: bool | None = None, + update_db_namespace: bool | None = None, + max_time_ms: int | None = None, **kwargs: Any, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ Create a namespace in the database, returning {'ok': 1} if successful. @@ -3538,12 +3538,12 @@ def create_keyspace( self, name: str, *, - replication_options: Optional[Dict[str, Any]] = None, - update_db_keyspace: Optional[bool] = None, - update_db_namespace: Optional[bool] = None, - max_time_ms: Optional[int] = None, + replication_options: dict[str, Any] | None = None, + update_db_keyspace: bool | None = None, + update_db_namespace: bool | None = None, + max_time_ms: int | None = None, **kwargs: Any, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ Create a keyspace in the database, returning {'ok': 1} if successful. @@ -3622,8 +3622,8 @@ def drop_namespace( self, name: str, *, - max_time_ms: Optional[int] = None, - ) -> Dict[str, Any]: + max_time_ms: int | None = None, + ) -> dict[str, Any]: """ Drop (delete) a namespace from the database. @@ -3667,8 +3667,8 @@ def drop_keyspace( self, name: str, *, - max_time_ms: Optional[int] = None, - ) -> Dict[str, Any]: + max_time_ms: int | None = None, + ) -> dict[str, Any]: """ Drop (delete) a keyspace from the database. @@ -3713,8 +3713,8 @@ def drop_keyspace( details=NAMESPACE_DEPRECATION_NOTICE_METHOD, ) async def async_list_namespaces( - self, *, max_time_ms: Optional[int] = None - ) -> List[str]: + self, *, max_time_ms: int | None = None + ) -> list[str]: """ Query the API for a list of the namespaces in the database. Async version of the method, for use in an asyncio context. @@ -3746,8 +3746,8 @@ async def async_list_namespaces( return fn_response["status"]["namespaces"] # type: ignore[no-any-return] async def async_list_keyspaces( - self, *, max_time_ms: Optional[int] = None - ) -> List[str]: + self, *, max_time_ms: int | None = None + ) -> list[str]: """ Query the API for a list of the keyspaces in the database. Async version of the method, for use in an asyncio context. @@ -3786,12 +3786,12 @@ async def async_create_namespace( self, name: str, *, - replication_options: Optional[Dict[str, Any]] = None, - update_db_keyspace: Optional[bool] = None, - update_db_namespace: Optional[bool] = None, - max_time_ms: Optional[int] = None, + replication_options: dict[str, Any] | None = None, + update_db_keyspace: bool | None = None, + update_db_namespace: bool | None = None, + max_time_ms: int | None = None, **kwargs: Any, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ Create a namespace in the database, returning {'ok': 1} if successful. Async version of the method, for use in an asyncio context. @@ -3869,12 +3869,12 @@ async def async_create_keyspace( self, name: str, *, - replication_options: Optional[Dict[str, Any]] = None, - update_db_keyspace: Optional[bool] = None, - update_db_namespace: Optional[bool] = None, - max_time_ms: Optional[int] = None, + replication_options: dict[str, Any] | None = None, + update_db_keyspace: bool | None = None, + update_db_namespace: bool | None = None, + max_time_ms: int | None = None, **kwargs: Any, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ Create a keyspace in the database, returning {'ok': 1} if successful. Async version of the method, for use in an asyncio context. @@ -3956,8 +3956,8 @@ async def async_drop_namespace( self, name: str, *, - max_time_ms: Optional[int] = None, - ) -> Dict[str, Any]: + max_time_ms: int | None = None, + ) -> dict[str, Any]: """ Drop (delete) a namespace from the database. Async version of the method, for use in an asyncio context. @@ -4004,8 +4004,8 @@ async def async_drop_keyspace( self, name: str, *, - max_time_ms: Optional[int] = None, - ) -> Dict[str, Any]: + max_time_ms: int | None = None, + ) -> dict[str, Any]: """ Drop (delete) a keyspace from the database. Async version of the method, for use in an asyncio context. @@ -4049,11 +4049,11 @@ async def async_drop_keyspace( def get_database( self, *, - token: Optional[Union[str, TokenProvider]] = None, - keyspace: Optional[str] = None, - namespace: Optional[str] = None, - api_path: Optional[str] = None, - api_version: Optional[str] = None, + token: str | TokenProvider | None = None, + keyspace: str | None = None, + namespace: str | None = None, + api_path: str | None = None, + api_version: str | None = None, ) -> Database: """ Create a Database instance out of this class for working with the data in it. @@ -4108,11 +4108,11 @@ def get_database( def get_async_database( self, *, - token: Optional[Union[str, TokenProvider]] = None, - keyspace: Optional[str] = None, - namespace: Optional[str] = None, - api_path: Optional[str] = None, - api_version: Optional[str] = None, + token: str | TokenProvider | None = None, + keyspace: str | None = None, + namespace: str | None = None, + api_path: str | None = None, + api_version: str | None = None, ) -> AsyncDatabase: """ Create an AsyncDatabase instance for the database, to be used @@ -4135,7 +4135,7 @@ def get_async_database( ).to_async() def find_embedding_providers( - self, *, max_time_ms: Optional[int] = None + self, *, max_time_ms: int | None = None ) -> FindEmbeddingProvidersResult: """ Query the API for the full information on available embedding providers. @@ -4178,7 +4178,7 @@ def find_embedding_providers( return FindEmbeddingProvidersResult.from_dict(fe_response["status"]) async def async_find_embedding_providers( - self, *, max_time_ms: Optional[int] = None + self, *, max_time_ms: int | None = None ) -> FindEmbeddingProvidersResult: """ Query the API for the full information on available embedding providers. diff --git a/astrapy/api_commander.py b/astrapy/api_commander.py index d3865446..2c5b2a9a 100644 --- a/astrapy/api_commander.py +++ b/astrapy/api_commander.py @@ -17,18 +17,7 @@ import json import logging from types import TracebackType -from typing import ( - TYPE_CHECKING, - Any, - Dict, - Iterable, - List, - Optional, - Tuple, - Type, - Union, - cast, -) +from typing import TYPE_CHECKING, Any, Dict, Iterable, cast import httpx @@ -77,8 +66,8 @@ def __init__( self, api_endpoint: str, path: str, - headers: Dict[str, Union[str, None]] = {}, - callers: List[Tuple[Optional[str], Optional[str]]] = [], + headers: dict[str, str | None] = {}, + callers: list[tuple[str | None, str | None]] = [], redacted_header_names: Iterable[str] = DEFAULT_REDACTED_HEADER_NAMES, dev_ops_api: bool = False, ) -> None: @@ -90,15 +79,14 @@ def __init__( self.redacted_header_names = set(redacted_header_names) self.dev_ops_api = dev_ops_api - self._faulty_response_exc_class: Union[ - Type[DevOpsAPIFaultyResponseException], Type[DataAPIFaultyResponseException] - ] - self._response_exc_class: Union[ - Type[DevOpsAPIResponseException], Type[DataAPIResponseException] - ] - self._http_exc_class: Union[ - Type[DataAPIHttpException], Type[DevOpsAPIHttpException] - ] + self._faulty_response_exc_class: ( + type[DevOpsAPIFaultyResponseException] + | type[DataAPIFaultyResponseException] + ) + self._response_exc_class: ( + type[DevOpsAPIResponseException] | type[DataAPIResponseException] + ) + self._http_exc_class: type[DataAPIHttpException] | type[DevOpsAPIHttpException] if self.dev_ops_api: self._faulty_response_exc_class = DevOpsAPIFaultyResponseException self._response_exc_class = DevOpsAPIResponseException @@ -112,10 +100,10 @@ def __init__( full_user_agent_string = compose_full_user_agent( [user_agent_ragstack] + self.callers + [user_agent_astrapy] ) - self.caller_header: Dict[str, str] = ( + self.caller_header: dict[str, str] = ( {"User-Agent": full_user_agent_string} if full_user_agent_string else {} ) - self.full_headers: Dict[str, str] = { + self.full_headers: dict[str, str] = { **{k: v for k, v in self.headers.items() if v is not None}, **self.caller_header, **{"Content-Type": "application/json"}, @@ -146,20 +134,20 @@ async def __aenter__(self) -> APICommander: async def __aexit__( self, - exc_type: Optional[Type[BaseException]] = None, - exc_value: Optional[BaseException] = None, - traceback: Optional[TracebackType] = None, + exc_type: type[BaseException] | None = None, + exc_value: BaseException | None = None, + traceback: TracebackType | None = None, ) -> None: await self.async_client.aclose() def _copy( self, - api_endpoint: Optional[str] = None, - path: Optional[str] = None, - headers: Optional[Dict[str, Union[str, None]]] = None, - callers: Optional[List[Tuple[Optional[str], Optional[str]]]] = None, - redacted_header_names: Optional[List[str]] = None, - dev_ops_api: Optional[bool] = None, + api_endpoint: str | None = None, + path: str | None = None, + headers: dict[str, str | None] | None = None, + callers: list[tuple[str | None, str | None]] | None = None, + redacted_header_names: list[str] | None = None, + dev_ops_api: bool | None = None, ) -> APICommander: # some care in allowing e.g. {} to override (but not None): return APICommander( @@ -181,10 +169,10 @@ def _raw_response_to_json( self, raw_response: httpx.Response, raise_api_errors: bool, - payload: Optional[Dict[str, Any]], - ) -> Dict[str, Any]: + payload: dict[str, Any] | None, + ) -> dict[str, Any]: # try to process the httpx raw response into a JSON or throw a failure - raw_response_json: Dict[str, Any] + raw_response_json: dict[str, Any] try: raw_response_json = cast( Dict[str, Any], @@ -214,7 +202,7 @@ def _raw_response_to_json( # no warnings check for DevOps API (there, 'status' may contain a string) if not self.dev_ops_api: - warning_messages: List[str] = (raw_response_json.get("status") or {}).get( + warning_messages: list[str] = (raw_response_json.get("status") or {}).get( "warnings" ) or [] if warning_messages: @@ -226,15 +214,15 @@ def _raw_response_to_json( response_json = restore_from_api(raw_response_json) return response_json - def _compose_request_url(self, additional_path: Optional[str]) -> str: + def _compose_request_url(self, additional_path: str | None) -> str: if additional_path: return "/".join([self.full_path.rstrip("/"), additional_path.lstrip("/")]) else: return self.full_path def _encode_payload( - self, normalized_payload: Optional[Dict[str, Any]] - ) -> Optional[bytes]: + self, normalized_payload: dict[str, Any] | None + ) -> bytes | None: if normalized_payload is not None: return json.dumps( normalized_payload, @@ -248,8 +236,8 @@ def raw_request( self, *, http_method: str = HttpMethod.POST, - payload: Optional[Dict[str, Any]] = None, - additional_path: Optional[str] = None, + payload: dict[str, Any] | None = None, + additional_path: str | None = None, raise_api_errors: bool = True, timeout_info: TimeoutInfoWideType = None, ) -> httpx.Response: @@ -290,8 +278,8 @@ async def async_raw_request( self, *, http_method: str = HttpMethod.POST, - payload: Optional[Dict[str, Any]] = None, - additional_path: Optional[str] = None, + payload: dict[str, Any] | None = None, + additional_path: str | None = None, raise_api_errors: bool = True, timeout_info: TimeoutInfoWideType = None, ) -> httpx.Response: @@ -332,11 +320,11 @@ def request( self, *, http_method: str = HttpMethod.POST, - payload: Optional[Dict[str, Any]] = None, - additional_path: Optional[str] = None, + payload: dict[str, Any] | None = None, + additional_path: str | None = None, raise_api_errors: bool = True, timeout_info: TimeoutInfoWideType = None, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: raw_response = self.raw_request( http_method=http_method, payload=payload, @@ -352,11 +340,11 @@ async def async_request( self, *, http_method: str = HttpMethod.POST, - payload: Optional[Dict[str, Any]] = None, - additional_path: Optional[str] = None, + payload: dict[str, Any] | None = None, + additional_path: str | None = None, raise_api_errors: bool = True, timeout_info: TimeoutInfoWideType = None, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: raw_response = await self.async_raw_request( http_method=http_method, payload=payload, diff --git a/astrapy/api_options.py b/astrapy/api_options.py index 5b84179b..d0f0dc7e 100644 --- a/astrapy/api_options.py +++ b/astrapy/api_options.py @@ -15,7 +15,7 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Optional, TypeVar +from typing import TypeVar from astrapy.authentication import ( EmbeddingAPIKeyHeaderProvider, @@ -41,9 +41,9 @@ class BaseAPIOptions: much sense. """ - max_time_ms: Optional[int] = None + max_time_ms: int | None = None - def with_default(self: AO, default: Optional[BaseAPIOptions]) -> AO: + def with_default(self: AO, default: BaseAPIOptions | None) -> AO: """ Return a new instance created by completing this instance with a default API options object. @@ -70,7 +70,7 @@ def with_default(self: AO, default: Optional[BaseAPIOptions]) -> AO: else: return self - def with_override(self: AO, override: Optional[BaseAPIOptions]) -> AO: + def with_override(self: AO, override: BaseAPIOptions | None) -> AO: """ Return a new instance created by overriding the members of this instance with those taken from a supplied "override" API options object. diff --git a/astrapy/authentication.py b/astrapy/authentication.py index fbe0776d..8b009f43 100644 --- a/astrapy/authentication.py +++ b/astrapy/authentication.py @@ -16,7 +16,7 @@ import base64 from abc import ABC, abstractmethod -from typing import Any, Dict, Optional, Union +from typing import Any from astrapy.defaults import ( EMBEDDING_HEADER_API_KEY, @@ -28,7 +28,7 @@ ) -def coerce_token_provider(token: Optional[Union[str, TokenProvider]]) -> TokenProvider: +def coerce_token_provider(token: str | TokenProvider | None) -> TokenProvider: if isinstance(token, TokenProvider): return token else: @@ -36,7 +36,7 @@ def coerce_token_provider(token: Optional[Union[str, TokenProvider]]) -> TokenPr def coerce_embedding_headers_provider( - embedding_api_key: Optional[Union[str, EmbeddingHeadersProvider]], + embedding_api_key: str | EmbeddingHeadersProvider | None, ) -> EmbeddingHeadersProvider: if isinstance(embedding_api_key, EmbeddingHeadersProvider): return embedding_api_key @@ -121,7 +121,7 @@ def __bool__(self) -> bool: return self.get_token() is not None @abstractmethod - def get_token(self) -> Union[str, None]: + def get_token(self) -> str | None: """ Produce a string for direct use as token in a subsequent API request, or None for no token. @@ -146,7 +146,7 @@ class StaticTokenProvider(TokenProvider): ... ) """ - def __init__(self, token: Union[str, None]) -> None: + def __init__(self, token: str | None) -> None: self.token = token def __repr__(self) -> str: @@ -155,7 +155,7 @@ def __repr__(self) -> str: else: return self.token - def get_token(self) -> Union[str, None]: + def get_token(self) -> str | None: return self.token @@ -230,7 +230,7 @@ def __bool__(self) -> bool: return self.get_headers() != {} @abstractmethod - def get_headers(self) -> Dict[str, str]: + def get_headers(self) -> dict[str, str]: """ Produce a dictionary for use as (part of) the headers in HTTP requests to the Data API. @@ -277,7 +277,7 @@ class EmbeddingAPIKeyHeaderProvider(EmbeddingHeadersProvider): ... ) """ - def __init__(self, embedding_api_key: Optional[str]) -> None: + def __init__(self, embedding_api_key: str | None) -> None: self.embedding_api_key = embedding_api_key def __repr__(self) -> str: @@ -286,7 +286,7 @@ def __repr__(self) -> str: else: return f'{self.__class__.__name__}("{redact_secret(self.embedding_api_key, 8)}")' - def get_headers(self) -> Dict[str, str]: + def get_headers(self) -> dict[str, str]: if self.embedding_api_key is not None: return {EMBEDDING_HEADER_API_KEY: self.embedding_api_key} else: @@ -347,7 +347,7 @@ def __repr__(self) -> str: f'embedding_secret_id="{redact_secret(self.embedding_secret_id, 6)}")' ) - def get_headers(self) -> Dict[str, str]: + def get_headers(self) -> dict[str, str]: return { EMBEDDING_HEADER_AWS_ACCESS_ID: self.embedding_access_id, EMBEDDING_HEADER_AWS_SECRET_ID: self.embedding_secret_id, diff --git a/astrapy/client.py b/astrapy/client.py index 82f6e982..18cdf7f0 100644 --- a/astrapy/client.py +++ b/astrapy/client.py @@ -16,7 +16,7 @@ import logging import re -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any from astrapy.admin import ( api_endpoint_parser, @@ -81,11 +81,11 @@ class DataAPIClient: def __init__( self, - token: Optional[Union[str, TokenProvider]] = None, + token: str | TokenProvider | None = None, *, - environment: Optional[str] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + environment: str | None = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> None: self.token_provider = coerce_token_provider(token) self.environment = (environment or Environment.PROD).lower() @@ -97,12 +97,12 @@ def __init__( self._caller_version = caller_version def __repr__(self) -> str: - token_desc: Optional[str] + token_desc: str | None if self.token_provider: token_desc = f'"{redact_secret(str(self.token_provider), 15)}"' else: token_desc = None - env_desc: Optional[str] + env_desc: str | None if self.environment == Environment.PROD: env_desc = None else: @@ -140,10 +140,10 @@ def __getitem__(self, database_id_or_api_endpoint: str) -> Database: def _copy( self, *, - token: Optional[Union[str, TokenProvider]] = None, - environment: Optional[str] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + token: str | TokenProvider | None = None, + environment: str | None = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> DataAPIClient: return DataAPIClient( token=coerce_token_provider(token) or self.token_provider, @@ -155,9 +155,9 @@ def _copy( def with_options( self, *, - token: Optional[Union[str, TokenProvider]] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + token: str | TokenProvider | None = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> DataAPIClient: """ Create a clone of this DataAPIClient with some changed attributes. @@ -189,8 +189,8 @@ def with_options( def set_caller( self, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> None: """ Set a new identity for the application/framework on behalf of which @@ -213,16 +213,16 @@ def set_caller( def get_database( self, - id: Optional[str] = None, + id: str | None = None, *, - api_endpoint: Optional[str] = None, - token: Optional[Union[str, TokenProvider]] = None, - keyspace: Optional[str] = None, - namespace: Optional[str] = None, - region: Optional[str] = None, - api_path: Optional[str] = None, - api_version: Optional[str] = None, - max_time_ms: Optional[int] = None, + api_endpoint: str | None = None, + token: str | TokenProvider | None = None, + keyspace: str | None = None, + namespace: str | None = None, + region: str | None = None, + api_path: str | None = None, + api_version: str | None = None, + max_time_ms: int | None = None, ) -> Database: """ Get a Database object from this client, for doing data-related work. @@ -349,16 +349,16 @@ def get_database( def get_async_database( self, - id: Optional[str] = None, + id: str | None = None, *, - api_endpoint: Optional[str] = None, - token: Optional[Union[str, TokenProvider]] = None, - keyspace: Optional[str] = None, - namespace: Optional[str] = None, - region: Optional[str] = None, - api_path: Optional[str] = None, - api_version: Optional[str] = None, - max_time_ms: Optional[int] = None, + api_endpoint: str | None = None, + token: str | TokenProvider | None = None, + keyspace: str | None = None, + namespace: str | None = None, + region: str | None = None, + api_path: str | None = None, + api_version: str | None = None, + max_time_ms: int | None = None, ) -> AsyncDatabase: """ Get an AsyncDatabase object from this client. @@ -386,11 +386,11 @@ def get_database_by_api_endpoint( self, api_endpoint: str, *, - token: Optional[Union[str, TokenProvider]] = None, - keyspace: Optional[str] = None, - namespace: Optional[str] = None, - api_path: Optional[str] = None, - api_version: Optional[str] = None, + token: str | TokenProvider | None = None, + keyspace: str | None = None, + namespace: str | None = None, + api_path: str | None = None, + api_version: str | None = None, ) -> Database: """ Get a Database object from this client, for doing data-related work. @@ -489,11 +489,11 @@ def get_async_database_by_api_endpoint( self, api_endpoint: str, *, - token: Optional[Union[str, TokenProvider]] = None, - keyspace: Optional[str] = None, - namespace: Optional[str] = None, - api_path: Optional[str] = None, - api_version: Optional[str] = None, + token: str | TokenProvider | None = None, + keyspace: str | None = None, + namespace: str | None = None, + api_path: str | None = None, + api_version: str | None = None, ) -> AsyncDatabase: """ Get an AsyncDatabase object from this client, for doing data-related work. @@ -522,9 +522,9 @@ def get_async_database_by_api_endpoint( def get_admin( self, *, - token: Optional[Union[str, TokenProvider]] = None, - dev_ops_url: Optional[str] = None, - dev_ops_api_version: Optional[str] = None, + token: str | TokenProvider | None = None, + dev_ops_url: str | None = None, + dev_ops_api_version: str | None = None, ) -> AstraDBAdmin: """ Get an AstraDBAdmin instance corresponding to this client, for diff --git a/astrapy/collection.py b/astrapy/collection.py index 5120f9e0..1a6e566e 100644 --- a/astrapy/collection.py +++ b/astrapy/collection.py @@ -20,17 +20,7 @@ import warnings from concurrent.futures import ThreadPoolExecutor from types import TracebackType -from typing import ( - TYPE_CHECKING, - Any, - Dict, - Iterable, - List, - Optional, - Tuple, - Type, - Union, -) +from typing import TYPE_CHECKING, Any, Iterable import deprecation @@ -87,7 +77,7 @@ logger = logging.getLogger(__name__) -def _prepare_update_info(statuses: List[Dict[str, Any]]) -> Dict[str, Any]: +def _prepare_update_info(statuses: list[dict[str, Any]]) -> dict[str, Any]: reduced_status = { "matchedCount": sum( status["matchedCount"] for status in statuses if "matchedCount" in status @@ -118,11 +108,11 @@ def _prepare_update_info(statuses: List[Dict[str, Any]]) -> Dict[str, Any]: def _collate_vector_to_sort( - sort: Optional[SortType], - vector: Optional[VectorType], - vectorize: Optional[str], -) -> Optional[SortType]: - _vsort: Dict[str, Any] + sort: SortType | None, + vector: VectorType | None, + vectorize: str | None, +) -> SortType | None: + _vsort: dict[str, Any] if vector is None: if vectorize is None: return sort @@ -149,7 +139,7 @@ def _collate_vector_to_sort( ) -def _is_vector_sort(sort: Optional[SortType]) -> bool: +def _is_vector_sort(sort: SortType | None) -> bool: if sort is None: return False else: @@ -157,7 +147,7 @@ def _is_vector_sort(sort: Optional[SortType]) -> bool: def _collate_vector_to_document( - document0: DocumentType, vector: Optional[VectorType], vectorize: Optional[str] + document0: DocumentType, vector: VectorType | None, vectorize: str | None ) -> DocumentType: if vector is None: if vectorize is None: @@ -193,9 +183,9 @@ def _collate_vector_to_document( def _collate_vectors_to_documents( documents: Iterable[DocumentType], - vectors: Optional[Iterable[Optional[VectorType]]], - vectorize: Optional[Iterable[Optional[str]]], -) -> List[DocumentType]: + vectors: Iterable[VectorType | None] | None, + vectorize: Iterable[str | None] | None, +) -> list[DocumentType]: if vectors is None and vectorize is None: return list(documents) else: @@ -267,11 +257,11 @@ def __init__( database: Database, name: str, *, - keyspace: Optional[str] = None, - namespace: Optional[str] = None, - api_options: Optional[CollectionAPIOptions] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + keyspace: str | None = None, + namespace: str | None = None, + api_options: CollectionAPIOptions | None = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> None: keyspace_param = check_namespace_keyspace( keyspace=keyspace, @@ -359,13 +349,13 @@ def _get_api_commander(self) -> APICommander: def _copy( self, *, - database: Optional[Database] = None, - name: Optional[str] = None, - keyspace: Optional[str] = None, - namespace: Optional[str] = None, - api_options: Optional[CollectionAPIOptions] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + database: Database | None = None, + name: str | None = None, + keyspace: str | None = None, + namespace: str | None = None, + api_options: CollectionAPIOptions | None = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> Collection: keyspace_param = check_namespace_keyspace( keyspace=keyspace, @@ -383,11 +373,11 @@ def _copy( def with_options( self, *, - name: Optional[str] = None, - embedding_api_key: Optional[Union[str, EmbeddingHeadersProvider]] = None, - collection_max_time_ms: Optional[int] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + name: str | None = None, + embedding_api_key: str | EmbeddingHeadersProvider | None = None, + collection_max_time_ms: int | None = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> Collection: """ Create a clone of this collection with some changed attributes. @@ -442,14 +432,14 @@ def with_options( def to_async( self, *, - database: Optional[AsyncDatabase] = None, - name: Optional[str] = None, - keyspace: Optional[str] = None, - namespace: Optional[str] = None, - embedding_api_key: Optional[Union[str, EmbeddingHeadersProvider]] = None, - collection_max_time_ms: Optional[int] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + database: AsyncDatabase | None = None, + name: str | None = None, + keyspace: str | None = None, + namespace: str | None = None, + embedding_api_key: str | EmbeddingHeadersProvider | None = None, + collection_max_time_ms: int | None = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> AsyncCollection: """ Create an AsyncCollection from this one. Save for the arguments @@ -514,8 +504,8 @@ def to_async( def set_caller( self, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> None: """ Set a new identity for the application/framework on behalf of which @@ -535,7 +525,7 @@ def set_caller( self.caller_version = caller_version or self.caller_version self._api_commander = self._get_api_commander() - def options(self, *, max_time_ms: Optional[int] = None) -> CollectionOptions: + def options(self, *, max_time_ms: int | None = None) -> CollectionOptions: """ Get the collection options, i.e. its configuration as read from the database. @@ -681,9 +671,9 @@ def insert_one( self, document: DocumentType, *, - vector: Optional[VectorType] = None, - vectorize: Optional[str] = None, - max_time_ms: Optional[int] = None, + vector: VectorType | None = None, + vectorize: str | None = None, + max_time_ms: int | None = None, ) -> InsertOneResult: """ Insert a single document in the collection in an atomic operation. @@ -769,12 +759,12 @@ def insert_many( self, documents: Iterable[DocumentType], *, - vectors: Optional[Iterable[Optional[VectorType]]] = None, - vectorize: Optional[Iterable[Optional[str]]] = None, + vectors: Iterable[VectorType | None] | None = None, + vectorize: Iterable[str | None] | None = None, ordered: bool = False, - chunk_size: Optional[int] = None, - concurrency: Optional[int] = None, - max_time_ms: Optional[int] = None, + chunk_size: int | None = None, + concurrency: int | None = None, + max_time_ms: int | None = None, ) -> InsertManyResult: """ Insert a list of documents into the collection. @@ -883,11 +873,11 @@ def insert_many( _documents = _collate_vectors_to_documents(documents, vectors, vectorize) _max_time_ms = max_time_ms or self.api_options.max_time_ms logger.info(f"inserting {len(_documents)} documents in '{self.name}'") - raw_results: List[Dict[str, Any]] = [] + raw_results: list[dict[str, Any]] = [] timeout_manager = MultiCallTimeoutManager(overall_max_time_ms=_max_time_ms) if ordered: options = {"ordered": True} - inserted_ids: List[Any] = [] + inserted_ids: list[Any] = [] for i in range(0, len(_documents), _chunk_size): im_payload = { "insertMany": { @@ -937,8 +927,8 @@ def insert_many( with ThreadPoolExecutor(max_workers=_concurrency) as executor: def _chunk_insertor( - document_chunk: List[Dict[str, Any]] - ) -> Dict[str, Any]: + document_chunk: list[dict[str, Any]] + ) -> dict[str, Any]: im_payload = { "insertMany": { "documents": document_chunk, @@ -1014,17 +1004,17 @@ def _chunk_insertor( def find( self, - filter: Optional[FilterType] = None, + filter: FilterType | None = None, *, - projection: Optional[ProjectionType] = None, - skip: Optional[int] = None, - limit: Optional[int] = None, - vector: Optional[VectorType] = None, - vectorize: Optional[str] = None, - include_similarity: Optional[bool] = None, - include_sort_vector: Optional[bool] = None, - sort: Optional[SortType] = None, - max_time_ms: Optional[int] = None, + projection: ProjectionType | None = None, + skip: int | None = None, + limit: int | None = None, + vector: VectorType | None = None, + vectorize: str | None = None, + include_similarity: bool | None = None, + include_sort_vector: bool | None = None, + sort: SortType | None = None, + max_time_ms: int | None = None, ) -> Cursor: """ Find documents on the collection, matching a certain provided filter. @@ -1227,15 +1217,15 @@ def find( def find_one( self, - filter: Optional[FilterType] = None, + filter: FilterType | None = None, *, - projection: Optional[ProjectionType] = None, - vector: Optional[VectorType] = None, - vectorize: Optional[str] = None, - include_similarity: Optional[bool] = None, - sort: Optional[SortType] = None, - max_time_ms: Optional[int] = None, - ) -> Union[DocumentType, None]: + projection: ProjectionType | None = None, + vector: VectorType | None = None, + vectorize: str | None = None, + include_similarity: bool | None = None, + sort: SortType | None = None, + max_time_ms: int | None = None, + ) -> DocumentType | None: """ Run a search, returning the first document in the collection that matches provided filters, if any is found. @@ -1336,9 +1326,9 @@ def distinct( self, key: str, *, - filter: Optional[FilterType] = None, - max_time_ms: Optional[int] = None, - ) -> List[Any]: + filter: FilterType | None = None, + max_time_ms: int | None = None, + ) -> list[Any]: """ Return a list of the unique values of `key` across the documents in the collection that match the provided filter. @@ -1416,7 +1406,7 @@ def count_documents( filter: FilterType, *, upper_bound: int, - max_time_ms: Optional[int] = None, + max_time_ms: int | None = None, ) -> int: """ Count the documents in the collection matching the specified filter. @@ -1496,7 +1486,7 @@ def count_documents( def estimated_document_count( self, *, - max_time_ms: Optional[int] = None, + max_time_ms: int | None = None, ) -> int: """ Query the API server for an estimate of the document count in the collection. @@ -1515,7 +1505,7 @@ def estimated_document_count( 35700 """ _max_time_ms = max_time_ms or self.api_options.max_time_ms - ed_payload: Dict[str, Any] = {"estimatedDocumentCount": {}} + ed_payload: dict[str, Any] = {"estimatedDocumentCount": {}} logger.info(f"estimatedDocumentCount on '{self.name}'") ed_response = self._api_commander.request( payload=ed_payload, @@ -1536,14 +1526,14 @@ def find_one_and_replace( filter: FilterType, replacement: DocumentType, *, - projection: Optional[ProjectionType] = None, - vector: Optional[VectorType] = None, - vectorize: Optional[str] = None, - sort: Optional[SortType] = None, + projection: ProjectionType | None = None, + vector: VectorType | None = None, + vectorize: str | None = None, + sort: SortType | None = None, upsert: bool = False, return_document: str = ReturnDocument.BEFORE, - max_time_ms: Optional[int] = None, - ) -> Union[DocumentType, None]: + max_time_ms: int | None = None, + ) -> DocumentType | None: """ Find a document on the collection and replace it entirely with a new one, optionally inserting a new one if no match is found. @@ -1684,11 +1674,11 @@ def replace_one( filter: FilterType, replacement: DocumentType, *, - vector: Optional[VectorType] = None, - vectorize: Optional[str] = None, - sort: Optional[SortType] = None, + vector: VectorType | None = None, + vectorize: str | None = None, + sort: SortType | None = None, upsert: bool = False, - max_time_ms: Optional[int] = None, + max_time_ms: int | None = None, ) -> UpdateResult: """ Replace a single document on the collection with a new one, @@ -1785,16 +1775,16 @@ def replace_one( def find_one_and_update( self, filter: FilterType, - update: Dict[str, Any], + update: dict[str, Any], *, - projection: Optional[ProjectionType] = None, - vector: Optional[VectorType] = None, - vectorize: Optional[str] = None, - sort: Optional[SortType] = None, + projection: ProjectionType | None = None, + vector: VectorType | None = None, + vectorize: str | None = None, + sort: SortType | None = None, upsert: bool = False, return_document: str = ReturnDocument.BEFORE, - max_time_ms: Optional[int] = None, - ) -> Union[DocumentType, None]: + max_time_ms: int | None = None, + ) -> DocumentType | None: """ Find a document on the collection and update it as requested, optionally inserting a new one if no match is found. @@ -1939,13 +1929,13 @@ def find_one_and_update( def update_one( self, filter: FilterType, - update: Dict[str, Any], + update: dict[str, Any], *, - vector: Optional[VectorType] = None, - vectorize: Optional[str] = None, - sort: Optional[SortType] = None, + vector: VectorType | None = None, + vectorize: str | None = None, + sort: SortType | None = None, upsert: bool = False, - max_time_ms: Optional[int] = None, + max_time_ms: int | None = None, ) -> UpdateResult: """ Update a single document on the collection as requested, @@ -2046,10 +2036,10 @@ def update_one( def update_many( self, filter: FilterType, - update: Dict[str, Any], + update: dict[str, Any], *, upsert: bool = False, - max_time_ms: Optional[int] = None, + max_time_ms: int | None = None, ) -> UpdateResult: """ Apply an update operations to all documents matching a condition, @@ -2109,9 +2099,9 @@ def update_many( api_options = { "upsert": upsert, } - page_state_options: Dict[str, str] = {} - um_responses: List[Dict[str, Any]] = [] - um_statuses: List[Dict[str, Any]] = [] + page_state_options: dict[str, str] = {} + um_responses: list[dict[str, Any]] = [] + um_statuses: list[dict[str, Any]] = [] must_proceed = True _max_time_ms = max_time_ms or self.api_options.max_time_ms logger.info(f"starting update_many on '{self.name}'") @@ -2177,12 +2167,12 @@ def find_one_and_delete( self, filter: FilterType, *, - projection: Optional[ProjectionType] = None, - vector: Optional[VectorType] = None, - vectorize: Optional[str] = None, - sort: Optional[SortType] = None, - max_time_ms: Optional[int] = None, - ) -> Union[DocumentType, None]: + projection: ProjectionType | None = None, + vector: VectorType | None = None, + vectorize: str | None = None, + sort: SortType | None = None, + max_time_ms: int | None = None, + ) -> DocumentType | None: """ Find a document in the collection and delete it. The deleted document, however, is the return value of the method. @@ -2294,10 +2284,10 @@ def delete_one( self, filter: FilterType, *, - vector: Optional[VectorType] = None, - vectorize: Optional[str] = None, - sort: Optional[SortType] = None, - max_time_ms: Optional[int] = None, + vector: VectorType | None = None, + vectorize: str | None = None, + sort: SortType | None = None, + max_time_ms: int | None = None, ) -> DeleteResult: """ Delete one document matching a provided filter. @@ -2399,7 +2389,7 @@ def delete_many( self, filter: FilterType, *, - max_time_ms: Optional[int] = None, + max_time_ms: int | None = None, ) -> DeleteResult: """ Delete all documents matching a provided filter. @@ -2443,7 +2433,7 @@ def delete_many( collection is devoid of matches. An exception is the `filter={}` case, whereby the operation is atomic. """ - dm_responses: List[Dict[str, Any]] = [] + dm_responses: list[dict[str, Any]] = [] deleted_count = 0 must_proceed = True _max_time_ms = max_time_ms or self.api_options.max_time_ms @@ -2493,7 +2483,7 @@ def delete_many( current_version=__version__, details="Use delete_many with filter={} instead.", ) - def delete_all(self, *, max_time_ms: Optional[int] = None) -> Dict[str, Any]: + def delete_all(self, *, max_time_ms: int | None = None) -> dict[str, Any]: """ Delete all documents in a collection. @@ -2540,8 +2530,8 @@ def bulk_write( requests: Iterable[BaseOperation], *, ordered: bool = False, - concurrency: Optional[int] = None, - max_time_ms: Optional[int] = None, + concurrency: int | None = None, + max_time_ms: int | None = None, ) -> BulkWriteResult: """ Execute an arbitrary amount of operations such as inserts, updates, deletes @@ -2605,7 +2595,7 @@ def bulk_write( logger.info(f"startng a bulk write on '{self.name}'") timeout_manager = MultiCallTimeoutManager(overall_max_time_ms=_max_time_ms) if ordered: - bulk_write_results: List[BulkWriteResult] = [] + bulk_write_results: list[BulkWriteResult] = [] for operation_i, operation in enumerate(requests): try: this_bw_result = operation.execute( @@ -2652,7 +2642,7 @@ def bulk_write( def _execute_as_either( operation: BaseOperation, operation_i: int - ) -> Tuple[Optional[BulkWriteResult], Optional[DataAPIResponseException]]: + ) -> tuple[BulkWriteResult | None, DataAPIResponseException | None]: try: ex_result = operation.execute( self, @@ -2712,7 +2702,7 @@ def _execute_as_either( logger.info(f"finished a bulk write on '{self.name}'") return reduce_bulk_write_results(bulk_write_successes) - def drop(self, *, max_time_ms: Optional[int] = None) -> Dict[str, Any]: + def drop(self, *, max_time_ms: int | None = None) -> dict[str, Any]: """ Drop the collection, i.e. delete it from the database along with all the documents it contains. @@ -2755,11 +2745,11 @@ def drop(self, *, max_time_ms: Optional[int] = None) -> Dict[str, Any]: def command( self, - body: Dict[str, Any], + body: dict[str, Any], *, raise_api_errors: bool = True, - max_time_ms: Optional[int] = None, - ) -> Dict[str, Any]: + max_time_ms: int | None = None, + ) -> dict[str, Any]: """ Send a POST request to the Data API for this collection with an arbitrary, caller-provided payload. @@ -2843,11 +2833,11 @@ def __init__( database: AsyncDatabase, name: str, *, - keyspace: Optional[str] = None, - namespace: Optional[str] = None, - api_options: Optional[CollectionAPIOptions] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + keyspace: str | None = None, + namespace: str | None = None, + api_options: CollectionAPIOptions | None = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> None: keyspace_param = check_namespace_keyspace( keyspace=keyspace, @@ -2939,9 +2929,9 @@ async def __aenter__(self) -> AsyncCollection: async def __aexit__( self, - exc_type: Optional[Type[BaseException]] = None, - exc_value: Optional[BaseException] = None, - traceback: Optional[TracebackType] = None, + exc_type: type[BaseException] | None = None, + exc_value: BaseException | None = None, + traceback: TracebackType | None = None, ) -> None: if self._api_commander is not None: await self._api_commander.__aexit__( @@ -2953,13 +2943,13 @@ async def __aexit__( def _copy( self, *, - database: Optional[AsyncDatabase] = None, - name: Optional[str] = None, - keyspace: Optional[str] = None, - namespace: Optional[str] = None, - api_options: Optional[CollectionAPIOptions] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + database: AsyncDatabase | None = None, + name: str | None = None, + keyspace: str | None = None, + namespace: str | None = None, + api_options: CollectionAPIOptions | None = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> AsyncCollection: keyspace_param = check_namespace_keyspace( keyspace=keyspace, @@ -2977,11 +2967,11 @@ def _copy( def with_options( self, *, - name: Optional[str] = None, - embedding_api_key: Optional[Union[str, EmbeddingHeadersProvider]] = None, - collection_max_time_ms: Optional[int] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + name: str | None = None, + embedding_api_key: str | EmbeddingHeadersProvider | None = None, + collection_max_time_ms: int | None = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> AsyncCollection: """ Create a clone of this collection with some changed attributes. @@ -3036,14 +3026,14 @@ def with_options( def to_sync( self, *, - database: Optional[Database] = None, - name: Optional[str] = None, - keyspace: Optional[str] = None, - namespace: Optional[str] = None, - embedding_api_key: Optional[Union[str, EmbeddingHeadersProvider]] = None, - collection_max_time_ms: Optional[int] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + database: Database | None = None, + name: str | None = None, + keyspace: str | None = None, + namespace: str | None = None, + embedding_api_key: str | EmbeddingHeadersProvider | None = None, + collection_max_time_ms: int | None = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> Collection: """ Create a Collection from this one. Save for the arguments @@ -3108,8 +3098,8 @@ def to_sync( def set_caller( self, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> None: """ Set a new identity for the application/framework on behalf of which @@ -3129,7 +3119,7 @@ def set_caller( self.caller_version = caller_version or self.caller_version self._api_commander = self._get_api_commander() - async def options(self, *, max_time_ms: Optional[int] = None) -> CollectionOptions: + async def options(self, *, max_time_ms: int | None = None) -> CollectionOptions: """ Get the collection options, i.e. its configuration as read from the database. @@ -3277,9 +3267,9 @@ async def insert_one( self, document: DocumentType, *, - vector: Optional[VectorType] = None, - vectorize: Optional[str] = None, - max_time_ms: Optional[int] = None, + vector: VectorType | None = None, + vectorize: str | None = None, + max_time_ms: int | None = None, ) -> InsertOneResult: """ Insert a single document in the collection in an atomic operation. @@ -3371,12 +3361,12 @@ async def insert_many( self, documents: Iterable[DocumentType], *, - vectors: Optional[Iterable[Optional[VectorType]]] = None, - vectorize: Optional[Iterable[Optional[str]]] = None, + vectors: Iterable[VectorType | None] | None = None, + vectorize: Iterable[str | None] | None = None, ordered: bool = False, - chunk_size: Optional[int] = None, - concurrency: Optional[int] = None, - max_time_ms: Optional[int] = None, + chunk_size: int | None = None, + concurrency: int | None = None, + max_time_ms: int | None = None, ) -> InsertManyResult: """ Insert a list of documents into the collection. @@ -3497,11 +3487,11 @@ async def insert_many( _documents = _collate_vectors_to_documents(documents, vectors, vectorize) _max_time_ms = max_time_ms or self.api_options.max_time_ms logger.info(f"inserting {len(_documents)} documents in '{self.name}'") - raw_results: List[Dict[str, Any]] = [] + raw_results: list[dict[str, Any]] = [] timeout_manager = MultiCallTimeoutManager(overall_max_time_ms=_max_time_ms) if ordered: options = {"ordered": True} - inserted_ids: List[Any] = [] + inserted_ids: list[Any] = [] for i in range(0, len(_documents), _chunk_size): im_payload = { "insertMany": { @@ -3551,8 +3541,8 @@ async def insert_many( sem = asyncio.Semaphore(_concurrency) async def concurrent_insert_chunk( - document_chunk: List[DocumentType], - ) -> Dict[str, Any]: + document_chunk: list[DocumentType], + ) -> dict[str, Any]: async with sem: im_payload = { "insertMany": { @@ -3618,17 +3608,17 @@ async def concurrent_insert_chunk( def find( self, - filter: Optional[FilterType] = None, + filter: FilterType | None = None, *, - projection: Optional[ProjectionType] = None, - skip: Optional[int] = None, - limit: Optional[int] = None, - vector: Optional[VectorType] = None, - vectorize: Optional[str] = None, - include_similarity: Optional[bool] = None, - include_sort_vector: Optional[bool] = None, - sort: Optional[SortType] = None, - max_time_ms: Optional[int] = None, + projection: ProjectionType | None = None, + skip: int | None = None, + limit: int | None = None, + vector: VectorType | None = None, + vectorize: str | None = None, + include_similarity: bool | None = None, + include_sort_vector: bool | None = None, + sort: SortType | None = None, + max_time_ms: int | None = None, ) -> AsyncCursor: """ Find documents on the collection, matching a certain provided filter. @@ -3841,15 +3831,15 @@ def find( async def find_one( self, - filter: Optional[FilterType] = None, + filter: FilterType | None = None, *, - projection: Optional[ProjectionType] = None, - vector: Optional[VectorType] = None, - vectorize: Optional[str] = None, - include_similarity: Optional[bool] = None, - sort: Optional[SortType] = None, - max_time_ms: Optional[int] = None, - ) -> Union[DocumentType, None]: + projection: ProjectionType | None = None, + vector: VectorType | None = None, + vectorize: str | None = None, + include_similarity: bool | None = None, + sort: SortType | None = None, + max_time_ms: int | None = None, + ) -> DocumentType | None: """ Run a search, returning the first document in the collection that matches provided filters, if any is found. @@ -3969,9 +3959,9 @@ async def distinct( self, key: str, *, - filter: Optional[FilterType] = None, - max_time_ms: Optional[int] = None, - ) -> List[Any]: + filter: FilterType | None = None, + max_time_ms: int | None = None, + ) -> list[Any]: """ Return a list of the unique values of `key` across the documents in the collection that match the provided filter. @@ -4057,7 +4047,7 @@ async def count_documents( filter: FilterType, *, upper_bound: int, - max_time_ms: Optional[int] = None, + max_time_ms: int | None = None, ) -> int: """ Count the documents in the collection matching the specified filter. @@ -4142,7 +4132,7 @@ async def count_documents( async def estimated_document_count( self, *, - max_time_ms: Optional[int] = None, + max_time_ms: int | None = None, ) -> int: """ Query the API server for an estimate of the document count in the collection. @@ -4161,7 +4151,7 @@ async def estimated_document_count( 35700 """ _max_time_ms = max_time_ms or self.api_options.max_time_ms - ed_payload: Dict[str, Any] = {"estimatedDocumentCount": {}} + ed_payload: dict[str, Any] = {"estimatedDocumentCount": {}} logger.info(f"estimatedDocumentCount on '{self.name}'") ed_response = await self._api_commander.async_request( payload=ed_payload, @@ -4182,14 +4172,14 @@ async def find_one_and_replace( filter: FilterType, replacement: DocumentType, *, - projection: Optional[ProjectionType] = None, - vector: Optional[VectorType] = None, - vectorize: Optional[str] = None, - sort: Optional[SortType] = None, + projection: ProjectionType | None = None, + vector: VectorType | None = None, + vectorize: str | None = None, + sort: SortType | None = None, upsert: bool = False, return_document: str = ReturnDocument.BEFORE, - max_time_ms: Optional[int] = None, - ) -> Union[DocumentType, None]: + max_time_ms: int | None = None, + ) -> DocumentType | None: """ Find a document on the collection and replace it entirely with a new one, optionally inserting a new one if no match is found. @@ -4339,11 +4329,11 @@ async def replace_one( filter: FilterType, replacement: DocumentType, *, - vector: Optional[VectorType] = None, - vectorize: Optional[str] = None, - sort: Optional[SortType] = None, + vector: VectorType | None = None, + vectorize: str | None = None, + sort: SortType | None = None, upsert: bool = False, - max_time_ms: Optional[int] = None, + max_time_ms: int | None = None, ) -> UpdateResult: """ Replace a single document on the collection with a new one, @@ -4459,16 +4449,16 @@ async def replace_one( async def find_one_and_update( self, filter: FilterType, - update: Dict[str, Any], + update: dict[str, Any], *, - projection: Optional[ProjectionType] = None, - vector: Optional[VectorType] = None, - vectorize: Optional[str] = None, - sort: Optional[SortType] = None, + projection: ProjectionType | None = None, + vector: VectorType | None = None, + vectorize: str | None = None, + sort: SortType | None = None, upsert: bool = False, return_document: str = ReturnDocument.BEFORE, - max_time_ms: Optional[int] = None, - ) -> Union[DocumentType, None]: + max_time_ms: int | None = None, + ) -> DocumentType | None: """ Find a document on the collection and update it as requested, optionally inserting a new one if no match is found. @@ -4622,13 +4612,13 @@ async def find_one_and_update( async def update_one( self, filter: FilterType, - update: Dict[str, Any], + update: dict[str, Any], *, - vector: Optional[VectorType] = None, - vectorize: Optional[str] = None, - sort: Optional[SortType] = None, + vector: VectorType | None = None, + vectorize: str | None = None, + sort: SortType | None = None, upsert: bool = False, - max_time_ms: Optional[int] = None, + max_time_ms: int | None = None, ) -> UpdateResult: """ Update a single document on the collection as requested, @@ -4747,10 +4737,10 @@ async def update_one( async def update_many( self, filter: FilterType, - update: Dict[str, Any], + update: dict[str, Any], *, upsert: bool = False, - max_time_ms: Optional[int] = None, + max_time_ms: int | None = None, ) -> UpdateResult: """ Apply an update operations to all documents matching a condition, @@ -4821,9 +4811,9 @@ async def update_many( api_options = { "upsert": upsert, } - page_state_options: Dict[str, str] = {} - um_responses: List[Dict[str, Any]] = [] - um_statuses: List[Dict[str, Any]] = [] + page_state_options: dict[str, str] = {} + um_responses: list[dict[str, Any]] = [] + um_statuses: list[dict[str, Any]] = [] must_proceed = True _max_time_ms = max_time_ms or self.api_options.max_time_ms logger.info(f"starting update_many on '{self.name}'") @@ -4889,12 +4879,12 @@ async def find_one_and_delete( self, filter: FilterType, *, - projection: Optional[ProjectionType] = None, - vector: Optional[VectorType] = None, - vectorize: Optional[str] = None, - sort: Optional[SortType] = None, - max_time_ms: Optional[int] = None, - ) -> Union[DocumentType, None]: + projection: ProjectionType | None = None, + vector: VectorType | None = None, + vectorize: str | None = None, + sort: SortType | None = None, + max_time_ms: int | None = None, + ) -> DocumentType | None: """ Find a document in the collection and delete it. The deleted document, however, is the return value of the method. @@ -5013,10 +5003,10 @@ async def delete_one( self, filter: FilterType, *, - vector: Optional[VectorType] = None, - vectorize: Optional[str] = None, - sort: Optional[SortType] = None, - max_time_ms: Optional[int] = None, + vector: VectorType | None = None, + vectorize: str | None = None, + sort: SortType | None = None, + max_time_ms: int | None = None, ) -> DeleteResult: """ Delete one document matching a provided filter. @@ -5119,7 +5109,7 @@ async def delete_many( self, filter: FilterType, *, - max_time_ms: Optional[int] = None, + max_time_ms: int | None = None, ) -> DeleteResult: """ Delete all documents matching a provided filter. @@ -5168,7 +5158,7 @@ async def delete_many( collection is devoid of matches. An exception is the `filter={}` case, whereby the operation is atomic. """ - dm_responses: List[Dict[str, Any]] = [] + dm_responses: list[dict[str, Any]] = [] deleted_count = 0 must_proceed = True _max_time_ms = max_time_ms or self.api_options.max_time_ms @@ -5218,7 +5208,7 @@ async def delete_many( current_version=__version__, details="Use delete_many with filter={} instead.", ) - async def delete_all(self, *, max_time_ms: Optional[int] = None) -> Dict[str, Any]: + async def delete_all(self, *, max_time_ms: int | None = None) -> dict[str, Any]: """ Delete all documents in a collection. @@ -5272,8 +5262,8 @@ async def bulk_write( requests: Iterable[AsyncBaseOperation], *, ordered: bool = False, - concurrency: Optional[int] = None, - max_time_ms: Optional[int] = None, + concurrency: int | None = None, + max_time_ms: int | None = None, ) -> BulkWriteResult: """ Execute an arbitrary amount of operations such as inserts, updates, deletes @@ -5353,7 +5343,7 @@ async def bulk_write( logger.info(f"startng a bulk write on '{self.name}'") timeout_manager = MultiCallTimeoutManager(overall_max_time_ms=_max_time_ms) if ordered: - bulk_write_results: List[BulkWriteResult] = [] + bulk_write_results: list[BulkWriteResult] = [] for operation_i, operation in enumerate(requests): try: this_bw_result = await operation.execute( @@ -5402,7 +5392,7 @@ async def bulk_write( async def _concurrent_execute_as_either( operation: AsyncBaseOperation, operation_i: int - ) -> Tuple[Optional[BulkWriteResult], Optional[DataAPIResponseException]]: + ) -> tuple[BulkWriteResult | None, DataAPIResponseException | None]: async with sem: try: ex_result = await operation.execute( @@ -5453,7 +5443,7 @@ async def _concurrent_execute_as_either( logger.info(f"finished a bulk write on '{self.name}'") return reduce_bulk_write_results(bulk_write_successes) - async def drop(self, *, max_time_ms: Optional[int] = None) -> Dict[str, Any]: + async def drop(self, *, max_time_ms: int | None = None) -> dict[str, Any]: """ Drop the collection, i.e. delete it from the database along with all the documents it contains. @@ -5503,11 +5493,11 @@ async def drop(self, *, max_time_ms: Optional[int] = None) -> Dict[str, Any]: async def command( self, - body: Dict[str, Any], + body: dict[str, Any], *, raise_api_errors: bool = True, - max_time_ms: Optional[int] = None, - ) -> Dict[str, Any]: + max_time_ms: int | None = None, + ) -> dict[str, Any]: """ Send a POST request to the Data API for this collection with an arbitrary, caller-provided payload. diff --git a/astrapy/constants.py b/astrapy/constants.py index 9a881805..84143b8d 100644 --- a/astrapy/constants.py +++ b/astrapy/constants.py @@ -14,7 +14,7 @@ from __future__ import annotations -from typing import Any, Dict, Iterable, Optional, Union +from typing import Any, Dict, Iterable, Union from astrapy.defaults import ( DATA_API_ENVIRONMENT_CASSANDRA, @@ -40,8 +40,8 @@ def normalize_optional_projection( - projection: Optional[ProjectionType], -) -> Optional[Dict[str, Union[bool, Dict[str, Union[int, Iterable[int]]]]]]: + projection: ProjectionType | None, +) -> dict[str, bool | dict[str, int | Iterable[int]]] | None: if projection: if isinstance(projection, dict): # already a dictionary diff --git a/astrapy/core/api.py b/astrapy/core/api.py index eea793c3..1de8d642 100644 --- a/astrapy/core/api.py +++ b/astrapy/core/api.py @@ -15,7 +15,7 @@ from __future__ import annotations import logging -from typing import Any, Dict, Optional, Union, cast +from typing import Any, cast import httpx @@ -27,7 +27,7 @@ class APIRequestError(ValueError): def __init__( - self, response: httpx.Response, payload: Optional[Dict[str, Any]] + self, response: httpx.Response, payload: dict[str, Any] | None ) -> None: super().__init__(response.text) @@ -42,15 +42,15 @@ def raw_api_request( client: httpx.Client, base_url: str, auth_header: str, - token: Optional[str], + token: str | None, method: str, - json_data: Optional[Dict[str, Any]], - url_params: Optional[Dict[str, Any]], - path: Optional[str], - caller_name: Optional[str], - caller_version: Optional[str], - timeout: Optional[Union[httpx.Timeout, float]], - additional_headers: Dict[str, str], + json_data: dict[str, Any] | None, + url_params: dict[str, Any] | None, + path: str | None, + caller_name: str | None, + caller_version: str | None, + timeout: httpx.Timeout | float | None, + additional_headers: dict[str, str], ) -> httpx.Response: return make_request( client=client, @@ -71,7 +71,7 @@ def raw_api_request( def process_raw_api_response( raw_response: httpx.Response, skip_error_check: bool, - json_data: Optional[Dict[str, Any]], + json_data: dict[str, Any] | None, ) -> API_RESPONSE: # In case of other successful responses, parse the JSON body. try: @@ -95,16 +95,16 @@ def api_request( client: httpx.Client, base_url: str, auth_header: str, - token: Optional[str], + token: str | None, method: str, - json_data: Optional[Dict[str, Any]], - url_params: Optional[Dict[str, Any]], - path: Optional[str], + json_data: dict[str, Any] | None, + url_params: dict[str, Any] | None, + path: str | None, skip_error_check: bool, - caller_name: Optional[str], - caller_version: Optional[str], - timeout: Optional[Union[httpx.Timeout, float]], - additional_headers: Dict[str, str], + caller_name: str | None, + caller_version: str | None, + timeout: httpx.Timeout | float | None, + additional_headers: dict[str, str], ) -> API_RESPONSE: raw_response = raw_api_request( client=client, @@ -131,15 +131,15 @@ async def async_raw_api_request( client: httpx.AsyncClient, base_url: str, auth_header: str, - token: Optional[str], + token: str | None, method: str, - json_data: Optional[Dict[str, Any]], - url_params: Optional[Dict[str, Any]], - path: Optional[str], - caller_name: Optional[str], - caller_version: Optional[str], - timeout: Optional[Union[httpx.Timeout, float]], - additional_headers: Dict[str, str], + json_data: dict[str, Any] | None, + url_params: dict[str, Any] | None, + path: str | None, + caller_name: str | None, + caller_version: str | None, + timeout: httpx.Timeout | float | None, + additional_headers: dict[str, str], ) -> httpx.Response: return await amake_request( client=client, @@ -160,7 +160,7 @@ async def async_raw_api_request( async def async_process_raw_api_response( raw_response: httpx.Response, skip_error_check: bool, - json_data: Optional[Dict[str, Any]], + json_data: dict[str, Any] | None, ) -> API_RESPONSE: # In case of other successful responses, parse the JSON body. try: @@ -184,16 +184,16 @@ async def async_api_request( client: httpx.AsyncClient, base_url: str, auth_header: str, - token: Optional[str], + token: str | None, method: str, - json_data: Optional[Dict[str, Any]], - url_params: Optional[Dict[str, Any]], - path: Optional[str], + json_data: dict[str, Any] | None, + url_params: dict[str, Any] | None, + path: str | None, skip_error_check: bool, - caller_name: Optional[str], - caller_version: Optional[str], - timeout: Optional[Union[httpx.Timeout, float]], - additional_headers: Dict[str, str], + caller_name: str | None, + caller_version: str | None, + timeout: httpx.Timeout | float | None, + additional_headers: dict[str, str], ) -> API_RESPONSE: raw_response = await async_raw_api_request( client=client, diff --git a/astrapy/core/core_types.py b/astrapy/core/core_types.py index 0b2f8c56..ffc61c87 100644 --- a/astrapy/core/core_types.py +++ b/astrapy/core/core_types.py @@ -36,9 +36,9 @@ # This is for the (partialed, if necessary) functions that can be "paginated". class PaginableRequestMethod(Protocol): - def __call__(self, options: Dict[str, Any]) -> API_RESPONSE: ... + def __call__(self, options: dict[str, Any]) -> API_RESPONSE: ... # This is for the (partialed, if necessary) async functions that can be "paginated". class AsyncPaginableRequestMethod(Protocol): - async def __call__(self, options: Dict[str, Any]) -> API_RESPONSE: ... + async def __call__(self, options: dict[str, Any]) -> API_RESPONSE: ... diff --git a/astrapy/core/db.py b/astrapy/core/db.py index 88bdbfee..723c49cb 100644 --- a/astrapy/core/db.py +++ b/astrapy/core/db.py @@ -24,18 +24,7 @@ from concurrent.futures import ThreadPoolExecutor from functools import partial from types import TracebackType -from typing import ( - Any, - Callable, - Dict, - Iterator, - List, - Optional, - Tuple, - Type, - Union, - cast, -) +from typing import Any, Callable, Iterator, List, Union, cast import httpx @@ -71,10 +60,10 @@ def __init__( self, prefetched: int, request_method: PaginableRequestMethod, - options: Optional[Dict[str, Any]], - raw_response_callback: Optional[Callable[[Dict[str, Any]], None]] = None, + options: dict[str, Any] | None, + raw_response_callback: Callable[[dict[str, Any]], None] | None = None, ): - self.queue: queue.Queue[Optional[API_DOC]] = queue.Queue(prefetched) + self.queue: queue.Queue[API_DOC | None] = queue.Queue(prefetched) self.request_method = request_method self.options = options self.raw_response_callback = raw_response_callback @@ -93,8 +82,8 @@ def __iter__(self) -> Iterator[API_DOC]: @staticmethod def queue_put( - q: queue.Queue[Optional[API_DOC]], - item: Optional[API_DOC], + q: queue.Queue[API_DOC | None], + item: API_DOC | None, stop: threading.Event, ) -> None: while not stop.is_set(): @@ -139,13 +128,13 @@ class AstraDBCollection: def __init__( self, collection_name: str, - astra_db: Optional[AstraDB] = None, - token: Optional[str] = None, - api_endpoint: Optional[str] = None, - namespace: Optional[str] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, - additional_headers: Dict[str, str] = {}, + astra_db: AstraDB | None = None, + token: str | None = None, + api_endpoint: str | None = None, + namespace: str | None = None, + caller_name: str | None = None, + caller_version: str | None = None, + additional_headers: dict[str, str] = {}, ) -> None: """ Initialize an AstraDBCollection instance. @@ -188,8 +177,8 @@ def __init__( # Set the remaining instance attributes self.astra_db = astra_db - self.caller_name: Optional[str] = self.astra_db.caller_name - self.caller_version: Optional[str] = self.astra_db.caller_version + self.caller_name: str | None = self.astra_db.caller_name + self.caller_version: str | None = self.astra_db.caller_version self.additional_headers = additional_headers self.collection_name = collection_name self.base_path: str = f"{self.astra_db.base_path}/{self.collection_name}" @@ -214,15 +203,15 @@ def __eq__(self, other: Any) -> bool: def copy( self, *, - collection_name: Optional[str] = None, - token: Optional[str] = None, - api_endpoint: Optional[str] = None, - api_path: Optional[str] = None, - api_version: Optional[str] = None, - namespace: Optional[str] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, - additional_headers: Optional[Dict[str, str]] = None, + collection_name: str | None = None, + token: str | None = None, + api_endpoint: str | None = None, + api_path: str | None = None, + api_version: str | None = None, + namespace: str | None = None, + caller_name: str | None = None, + caller_version: str | None = None, + additional_headers: dict[str, str] | None = None, ) -> AstraDBCollection: return AstraDBCollection( collection_name=collection_name or self.collection_name, @@ -251,8 +240,8 @@ def to_async(self) -> AsyncAstraDBCollection: def set_caller( self, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> None: self.astra_db.set_caller( caller_name=caller_name, @@ -264,9 +253,9 @@ def set_caller( def _request( self, method: str = http_methods.POST, - path: Optional[str] = None, - json_data: Optional[Dict[str, Any]] = None, - url_params: Optional[Dict[str, Any]] = None, + path: str | None = None, + json_data: dict[str, Any] | None = None, + url_params: dict[str, Any] | None = None, skip_error_check: bool = False, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: @@ -289,7 +278,7 @@ def _request( return response def post_raw_request( - self, body: Dict[str, Any], timeout_info: TimeoutInfoWideType = None + self, body: dict[str, Any], timeout_info: TimeoutInfoWideType = None ) -> API_RESPONSE: return self._request( method=http_methods.POST, @@ -300,10 +289,10 @@ def post_raw_request( def _get( self, - path: Optional[str] = None, - options: Optional[Dict[str, Any]] = None, + path: str | None = None, + options: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, - ) -> Optional[API_RESPONSE]: + ) -> API_RESPONSE | None: full_path = f"{self.base_path}/{path}" if path else self.base_path response = self._request( method=http_methods.GET, @@ -317,8 +306,8 @@ def _get( def _put( self, - path: Optional[str] = None, - document: Optional[API_RESPONSE] = None, + path: str | None = None, + document: API_RESPONSE | None = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: full_path = f"{self.base_path}/{path}" if path else self.base_path @@ -332,8 +321,8 @@ def _put( def _post( self, - path: Optional[str] = None, - document: Optional[API_DOC] = None, + path: str | None = None, + document: API_DOC | None = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: full_path = f"{self.base_path}/{path}" if path else self.base_path @@ -346,8 +335,8 @@ def _post( return response def _recast_as_sort_projection( - self, vector: List[float], fields: Optional[List[str]] = None - ) -> Tuple[Dict[str, Any], Optional[Dict[str, Any]]]: + self, vector: list[float], fields: list[str] | None = None + ) -> tuple[dict[str, Any], dict[str, Any] | None]: """ Given a vector and optionally a list of fields, reformulate them as a sort, projection pair for regular @@ -362,7 +351,7 @@ def _recast_as_sort_projection( raise ValueError("Please use the `include_similarity` parameter") # Build the new vector parameter - sort: Dict[str, Any] = {"$vector": vector} + sort: dict[str, Any] = {"$vector": vector} # Build the new fields parameter # Note: do not leave projection={}, make it None @@ -375,8 +364,8 @@ def _recast_as_sort_projection( return sort, projection def get( - self, path: Optional[str] = None, timeout_info: TimeoutInfoWideType = None - ) -> Optional[API_RESPONSE]: + self, path: str | None = None, timeout_info: TimeoutInfoWideType = None + ) -> API_RESPONSE | None: """ Retrieve a document from the collection by its path. @@ -393,10 +382,10 @@ def get( def find( self, - filter: Optional[Dict[str, Any]] = None, - projection: Optional[Dict[str, Any]] = None, - sort: Optional[Dict[str, Any]] = None, - options: Optional[Dict[str, Any]] = None, + filter: dict[str, Any] | None = None, + projection: dict[str, Any] | None = None, + sort: dict[str, Any] | None = None, + options: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ @@ -428,14 +417,14 @@ def find( def vector_find( self, - vector: List[float], + vector: list[float], *, limit: int, - filter: Optional[Dict[str, Any]] = None, - fields: Optional[List[str]] = None, + filter: dict[str, Any] | None = None, + fields: list[str] | None = None, include_similarity: bool = True, timeout_info: TimeoutInfoWideType = None, - ) -> List[API_DOC]: + ) -> list[API_DOC]: """ Perform a vector-based search in the collection. @@ -480,8 +469,8 @@ def vector_find( def paginate( *, request_method: PaginableRequestMethod, - options: Optional[Dict[str, Any]], - raw_response_callback: Optional[Callable[[Dict[str, Any]], None]] = None, + options: dict[str, Any] | None, + raw_response_callback: Callable[[dict[str, Any]], None] | None = None, ) -> Iterator[API_DOC]: """ Generate paginated results for a given database query method. @@ -517,13 +506,13 @@ def paginate( def paginated_find( self, - filter: Optional[Dict[str, Any]] = None, - projection: Optional[Dict[str, Any]] = None, - sort: Optional[Dict[str, Any]] = None, - options: Optional[Dict[str, Any]] = None, - prefetched: Optional[int] = None, + filter: dict[str, Any] | None = None, + projection: dict[str, Any] | None = None, + sort: dict[str, Any] | None = None, + options: dict[str, Any] | None = None, + prefetched: int | None = None, timeout_info: TimeoutInfoWideType = None, - raw_response_callback: Optional[Callable[[Dict[str, Any]], None]] = None, + raw_response_callback: Callable[[dict[str, Any]], None] | None = None, ) -> Iterator[API_DOC]: """ Perform a paginated search in the collection. @@ -569,9 +558,9 @@ def paginated_find( def pop( self, - filter: Dict[str, Any], - pop: Dict[str, Any], - options: Dict[str, Any], + filter: dict[str, Any], + pop: dict[str, Any], + options: dict[str, Any], timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ @@ -606,9 +595,9 @@ def pop( def push( self, - filter: Dict[str, Any], - push: Dict[str, Any], - options: Dict[str, Any], + filter: dict[str, Any], + push: dict[str, Any], + options: dict[str, Any], timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ @@ -643,12 +632,12 @@ def push( def find_one_and_replace( self, - replacement: Dict[str, Any], + replacement: dict[str, Any], *, - filter: Optional[Dict[str, Any]] = None, - projection: Optional[Dict[str, Any]] = None, - sort: Optional[Dict[str, Any]] = None, - options: Optional[Dict[str, Any]] = None, + filter: dict[str, Any] | None = None, + projection: dict[str, Any] | None = None, + sort: dict[str, Any] | None = None, + options: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ @@ -686,13 +675,13 @@ def find_one_and_replace( def vector_find_one_and_replace( self, - vector: List[float], - replacement: Dict[str, Any], + vector: list[float], + replacement: dict[str, Any], *, - filter: Optional[Dict[str, Any]] = None, - fields: Optional[List[str]] = None, + filter: dict[str, Any] | None = None, + fields: list[str] | None = None, timeout_info: TimeoutInfoWideType = None, - ) -> Union[API_DOC, None]: + ) -> API_DOC | None: """ Perform a vector-based search and replace the first matched document. @@ -727,11 +716,11 @@ def vector_find_one_and_replace( def find_one_and_update( self, - update: Dict[str, Any], - sort: Optional[Dict[str, Any]] = {}, - filter: Optional[Dict[str, Any]] = None, - options: Optional[Dict[str, Any]] = None, - projection: Optional[Dict[str, Any]] = None, + update: dict[str, Any], + sort: dict[str, Any] | None = {}, + filter: dict[str, Any] | None = None, + options: dict[str, Any] | None = None, + projection: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ @@ -769,13 +758,13 @@ def find_one_and_update( def vector_find_one_and_update( self, - vector: List[float], - update: Dict[str, Any], + vector: list[float], + update: dict[str, Any], *, - filter: Optional[Dict[str, Any]] = None, - fields: Optional[List[str]] = None, + filter: dict[str, Any] | None = None, + fields: list[str] | None = None, timeout_info: TimeoutInfoWideType = None, - ) -> Union[API_DOC, None]: + ) -> API_DOC | None: """ Perform a vector-based search and update the first matched document. @@ -811,9 +800,9 @@ def vector_find_one_and_update( def find_one_and_delete( self, - sort: Optional[Dict[str, Any]] = {}, - filter: Optional[Dict[str, Any]] = None, - projection: Optional[Dict[str, Any]] = None, + sort: dict[str, Any] | None = {}, + filter: dict[str, Any] | None = None, + projection: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ @@ -847,7 +836,7 @@ def find_one_and_delete( return response def count_documents( - self, filter: Dict[str, Any] = {}, timeout_info: TimeoutInfoWideType = None + self, filter: dict[str, Any] = {}, timeout_info: TimeoutInfoWideType = None ) -> API_RESPONSE: """ Count documents matching a given predicate (expressed as filter). @@ -875,10 +864,10 @@ def count_documents( def find_one( self, - filter: Optional[Dict[str, Any]] = {}, - projection: Optional[Dict[str, Any]] = {}, - sort: Optional[Dict[str, Any]] = {}, - options: Optional[Dict[str, Any]] = {}, + filter: dict[str, Any] | None = {}, + projection: dict[str, Any] | None = {}, + sort: dict[str, Any] | None = {}, + options: dict[str, Any] | None = {}, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ @@ -914,13 +903,13 @@ def find_one( def vector_find_one( self, - vector: List[float], + vector: list[float], *, - filter: Optional[Dict[str, Any]] = None, - fields: Optional[List[str]] = None, + filter: dict[str, Any] | None = None, + fields: list[str] | None = None, include_similarity: bool = True, timeout_info: TimeoutInfoWideType = None, - ) -> Union[API_DOC, None]: + ) -> API_DOC | None: """ Perform a vector-based search to find a single document in the collection. @@ -986,8 +975,8 @@ def insert_one( def insert_many( self, - documents: List[API_DOC], - options: Optional[Dict[str, Any]] = None, + documents: list[API_DOC], + options: dict[str, Any] | None = None, partial_failures_allowed: bool = False, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: @@ -1023,13 +1012,13 @@ def insert_many( def chunked_insert_many( self, - documents: List[API_DOC], - options: Optional[Dict[str, Any]] = None, + documents: list[API_DOC], + options: dict[str, Any] | None = None, partial_failures_allowed: bool = False, chunk_size: int = DEFAULT_INSERT_NUM_DOCUMENTS, concurrency: int = 1, timeout_info: TimeoutInfoWideType = None, - ) -> List[Union[API_RESPONSE, Exception]]: + ) -> list[API_RESPONSE | Exception]: """ Insert multiple documents into the collection, handling chunking and optionally with concurrent insertions. @@ -1054,7 +1043,7 @@ def chunked_insert_many( This is a list of individual responses from the API: the caller will need to inspect them all, e.g. to collate the inserted IDs. """ - results: List[Union[API_RESPONSE, Exception]] = [] + results: list[API_RESPONSE | Exception] = [] # Raise a warning if ordered and concurrency if options and options.get("ordered") is True and concurrency > 1: @@ -1110,11 +1099,11 @@ def chunked_insert_many( def update_one( self, - filter: Dict[str, Any], - update: Dict[str, Any], - sort: Optional[Dict[str, Any]] = None, + filter: dict[str, Any], + update: dict[str, Any], + sort: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, - options: Optional[Dict[str, Any]] = None, + options: dict[str, Any] | None = None, ) -> API_RESPONSE: """ Update a single document in the collection. @@ -1149,9 +1138,9 @@ def update_one( def update_many( self, - filter: Dict[str, Any], - update: Dict[str, Any], - options: Optional[Dict[str, Any]] = None, + filter: dict[str, Any], + update: dict[str, Any], + options: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ @@ -1204,7 +1193,7 @@ def replace( def delete_one( self, id: str, - sort: Optional[Dict[str, Any]] = None, + sort: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ @@ -1236,8 +1225,8 @@ def delete_one( def delete_one_by_predicate( self, - filter: Dict[str, Any], - sort: Optional[Dict[str, Any]] = None, + filter: dict[str, Any], + sort: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ @@ -1269,7 +1258,7 @@ def delete_one_by_predicate( def delete_many( self, - filter: Dict[str, Any], + filter: dict[str, Any], skip_error_check: bool = False, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: @@ -1304,8 +1293,8 @@ def delete_many( return response def chunked_delete_many( - self, filter: Dict[str, Any], timeout_info: TimeoutInfoWideType = None - ) -> List[API_RESPONSE]: + self, filter: dict[str, Any], timeout_info: TimeoutInfoWideType = None + ) -> list[API_RESPONSE]: """ Delete many documents from the collection based on a filter condition, chaining several API calls until exhaustion of the documents to delete. @@ -1436,7 +1425,7 @@ def upsert_many( concurrency: int = 1, partial_failures_allowed: bool = False, timeout_info: TimeoutInfoWideType = None, - ) -> List[Union[str, Exception]]: + ) -> list[str | Exception]: """ Emulate an upsert operation for multiple documents in the collection. @@ -1457,7 +1446,7 @@ def upsert_many( Returns: List[Union[str, Exception]]: A list of "_id"s of the inserted or updated documents. """ - results: List[Union[str, Exception]] = [] + results: list[str | Exception] = [] # If concurrency is 1, no need for thread pool if concurrency == 1: @@ -1493,13 +1482,13 @@ class AsyncAstraDBCollection: def __init__( self, collection_name: str, - astra_db: Optional[AsyncAstraDB] = None, - token: Optional[str] = None, - api_endpoint: Optional[str] = None, - namespace: Optional[str] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, - additional_headers: Dict[str, str] = {}, + astra_db: AsyncAstraDB | None = None, + token: str | None = None, + api_endpoint: str | None = None, + namespace: str | None = None, + caller_name: str | None = None, + caller_version: str | None = None, + additional_headers: dict[str, str] = {}, ) -> None: """ Initialize an AstraDBCollection instance. @@ -1542,8 +1531,8 @@ def __init__( # Set the remaining instance attributes self.astra_db: AsyncAstraDB = astra_db - self.caller_name: Optional[str] = self.astra_db.caller_name - self.caller_version: Optional[str] = self.astra_db.caller_version + self.caller_name: str | None = self.astra_db.caller_name + self.caller_version: str | None = self.astra_db.caller_version self.additional_headers = additional_headers self.client = astra_db.client self.collection_name = collection_name @@ -1569,15 +1558,15 @@ def __eq__(self, other: Any) -> bool: def copy( self, *, - collection_name: Optional[str] = None, - token: Optional[str] = None, - api_endpoint: Optional[str] = None, - api_path: Optional[str] = None, - api_version: Optional[str] = None, - namespace: Optional[str] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, - additional_headers: Optional[Dict[str, str]] = None, + collection_name: str | None = None, + token: str | None = None, + api_endpoint: str | None = None, + api_path: str | None = None, + api_version: str | None = None, + namespace: str | None = None, + caller_name: str | None = None, + caller_version: str | None = None, + additional_headers: dict[str, str] | None = None, ) -> AsyncAstraDBCollection: return AsyncAstraDBCollection( collection_name=collection_name or self.collection_name, @@ -1597,8 +1586,8 @@ def copy( def set_caller( self, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> None: self.astra_db.set_caller( caller_name=caller_name, @@ -1619,9 +1608,9 @@ def to_sync(self) -> AstraDBCollection: async def _request( self, method: str = http_methods.POST, - path: Optional[str] = None, - json_data: Optional[Dict[str, Any]] = None, - url_params: Optional[Dict[str, Any]] = None, + path: str | None = None, + json_data: dict[str, Any] | None = None, + url_params: dict[str, Any] | None = None, skip_error_check: bool = False, timeout_info: TimeoutInfoWideType = None, **kwargs: Any, @@ -1645,7 +1634,7 @@ async def _request( return response async def post_raw_request( - self, body: Dict[str, Any], timeout_info: TimeoutInfoWideType = None + self, body: dict[str, Any], timeout_info: TimeoutInfoWideType = None ) -> API_RESPONSE: return await self._request( method=http_methods.POST, @@ -1656,10 +1645,10 @@ async def post_raw_request( async def _get( self, - path: Optional[str] = None, - options: Optional[Dict[str, Any]] = None, + path: str | None = None, + options: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, - ) -> Optional[API_RESPONSE]: + ) -> API_RESPONSE | None: full_path = f"{self.base_path}/{path}" if path else self.base_path response = await self._request( method=http_methods.GET, @@ -1673,8 +1662,8 @@ async def _get( async def _put( self, - path: Optional[str] = None, - document: Optional[API_RESPONSE] = None, + path: str | None = None, + document: API_RESPONSE | None = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: full_path = f"{self.base_path}/{path}" if path else self.base_path @@ -1688,8 +1677,8 @@ async def _put( async def _post( self, - path: Optional[str] = None, - document: Optional[API_DOC] = None, + path: str | None = None, + document: API_DOC | None = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: full_path = f"{self.base_path}/{path}" if path else self.base_path @@ -1702,8 +1691,8 @@ async def _post( return response def _recast_as_sort_projection( - self, vector: List[float], fields: Optional[List[str]] = None - ) -> Tuple[Dict[str, Any], Optional[Dict[str, Any]]]: + self, vector: list[float], fields: list[str] | None = None + ) -> tuple[dict[str, Any], dict[str, Any] | None]: """ Given a vector and optionally a list of fields, reformulate them as a sort, projection pair for regular @@ -1718,7 +1707,7 @@ def _recast_as_sort_projection( raise ValueError("Please use the `include_similarity` parameter") # Build the new vector parameter - sort: Dict[str, Any] = {"$vector": vector} + sort: dict[str, Any] = {"$vector": vector} # Build the new fields parameter # Note: do not leave projection={}, make it None @@ -1731,8 +1720,8 @@ def _recast_as_sort_projection( return sort, projection async def get( - self, path: Optional[str] = None, timeout_info: TimeoutInfoWideType = None - ) -> Optional[API_RESPONSE]: + self, path: str | None = None, timeout_info: TimeoutInfoWideType = None + ) -> API_RESPONSE | None: """ Retrieve a document from the collection by its path. @@ -1749,10 +1738,10 @@ async def get( async def find( self, - filter: Optional[Dict[str, Any]] = None, - projection: Optional[Dict[str, Any]] = None, - sort: Optional[Dict[str, Any]] = None, - options: Optional[Dict[str, Any]] = None, + filter: dict[str, Any] | None = None, + projection: dict[str, Any] | None = None, + sort: dict[str, Any] | None = None, + options: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ @@ -1784,14 +1773,14 @@ async def find( async def vector_find( self, - vector: List[float], + vector: list[float], *, limit: int, - filter: Optional[Dict[str, Any]] = None, - fields: Optional[List[str]] = None, + filter: dict[str, Any] | None = None, + fields: list[str] | None = None, include_similarity: bool = True, timeout_info: TimeoutInfoWideType = None, - ) -> List[API_DOC]: + ) -> list[API_DOC]: """ Perform a vector-based search in the collection. @@ -1836,10 +1825,10 @@ async def vector_find( async def paginate( *, request_method: AsyncPaginableRequestMethod, - options: Optional[Dict[str, Any]], - prefetched: Optional[int] = None, + options: dict[str, Any] | None, + prefetched: int | None = None, timeout_info: TimeoutInfoWideType = None, - raw_response_callback: Optional[Callable[[Dict[str, Any]], None]] = None, + raw_response_callback: Callable[[dict[str, Any]], None] | None = None, ) -> AsyncGenerator[API_DOC, None]: """ Generate paginated results for a given database query method. @@ -1867,9 +1856,9 @@ async def paginate( if next_page_state is not None and prefetched: async def queued_paginate( - queue: asyncio.Queue[Optional[API_DOC]], + queue: asyncio.Queue[API_DOC | None], request_method: AsyncPaginableRequestMethod, - options: Optional[Dict[str, Any]], + options: dict[str, Any] | None, ) -> None: try: async for doc in AsyncAstraDBCollection.paginate( @@ -1879,7 +1868,7 @@ async def queued_paginate( finally: await queue.put(None) - queue: asyncio.Queue[Optional[API_DOC]] = asyncio.Queue(prefetched) + queue: asyncio.Queue[API_DOC | None] = asyncio.Queue(prefetched) options1 = {**options0, **{"pageState": next_page_state}} asyncio.create_task(queued_paginate(queue, request_method, options1)) for document in response0["data"]["documents"]: @@ -1902,13 +1891,13 @@ async def queued_paginate( def paginated_find( self, - filter: Optional[Dict[str, Any]] = None, - projection: Optional[Dict[str, Any]] = None, - sort: Optional[Dict[str, Any]] = None, - options: Optional[Dict[str, Any]] = None, - prefetched: Optional[int] = None, + filter: dict[str, Any] | None = None, + projection: dict[str, Any] | None = None, + sort: dict[str, Any] | None = None, + options: dict[str, Any] | None = None, + prefetched: int | None = None, timeout_info: TimeoutInfoWideType = None, - raw_response_callback: Optional[Callable[[Dict[str, Any]], None]] = None, + raw_response_callback: Callable[[dict[str, Any]], None] | None = None, ) -> AsyncIterator[API_DOC]: """ Perform a paginated search in the collection. @@ -1948,9 +1937,9 @@ def paginated_find( async def pop( self, - filter: Dict[str, Any], - pop: Dict[str, Any], - options: Dict[str, Any], + filter: dict[str, Any], + pop: dict[str, Any], + options: dict[str, Any], timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ @@ -1985,9 +1974,9 @@ async def pop( async def push( self, - filter: Dict[str, Any], - push: Dict[str, Any], - options: Dict[str, Any], + filter: dict[str, Any], + push: dict[str, Any], + options: dict[str, Any], timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ @@ -2022,12 +2011,12 @@ async def push( async def find_one_and_replace( self, - replacement: Dict[str, Any], + replacement: dict[str, Any], *, - filter: Optional[Dict[str, Any]] = None, - projection: Optional[Dict[str, Any]] = None, - sort: Optional[Dict[str, Any]] = None, - options: Optional[Dict[str, Any]] = None, + filter: dict[str, Any] | None = None, + projection: dict[str, Any] | None = None, + sort: dict[str, Any] | None = None, + options: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ @@ -2065,13 +2054,13 @@ async def find_one_and_replace( async def vector_find_one_and_replace( self, - vector: List[float], - replacement: Dict[str, Any], + vector: list[float], + replacement: dict[str, Any], *, - filter: Optional[Dict[str, Any]] = None, - fields: Optional[List[str]] = None, + filter: dict[str, Any] | None = None, + fields: list[str] | None = None, timeout_info: TimeoutInfoWideType = None, - ) -> Union[API_DOC, None]: + ) -> API_DOC | None: """ Perform a vector-based search and replace the first matched document. @@ -2106,11 +2095,11 @@ async def vector_find_one_and_replace( async def find_one_and_update( self, - update: Dict[str, Any], - sort: Optional[Dict[str, Any]] = {}, - filter: Optional[Dict[str, Any]] = None, - options: Optional[Dict[str, Any]] = None, - projection: Optional[Dict[str, Any]] = None, + update: dict[str, Any], + sort: dict[str, Any] | None = {}, + filter: dict[str, Any] | None = None, + options: dict[str, Any] | None = None, + projection: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ @@ -2148,13 +2137,13 @@ async def find_one_and_update( async def vector_find_one_and_update( self, - vector: List[float], - update: Dict[str, Any], + vector: list[float], + update: dict[str, Any], *, - filter: Optional[Dict[str, Any]] = None, - fields: Optional[List[str]] = None, + filter: dict[str, Any] | None = None, + fields: list[str] | None = None, timeout_info: TimeoutInfoWideType = None, - ) -> Union[API_DOC, None]: + ) -> API_DOC | None: """ Perform a vector-based search and update the first matched document. @@ -2190,9 +2179,9 @@ async def vector_find_one_and_update( async def find_one_and_delete( self, - sort: Optional[Dict[str, Any]] = {}, - filter: Optional[Dict[str, Any]] = None, - projection: Optional[Dict[str, Any]] = None, + sort: dict[str, Any] | None = {}, + filter: dict[str, Any] | None = None, + projection: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ @@ -2226,7 +2215,7 @@ async def find_one_and_delete( return response async def count_documents( - self, filter: Dict[str, Any] = {}, timeout_info: TimeoutInfoWideType = None + self, filter: dict[str, Any] = {}, timeout_info: TimeoutInfoWideType = None ) -> API_RESPONSE: """ Count documents matching a given predicate (expressed as filter). @@ -2254,10 +2243,10 @@ async def count_documents( async def find_one( self, - filter: Optional[Dict[str, Any]] = {}, - projection: Optional[Dict[str, Any]] = {}, - sort: Optional[Dict[str, Any]] = {}, - options: Optional[Dict[str, Any]] = {}, + filter: dict[str, Any] | None = {}, + projection: dict[str, Any] | None = {}, + sort: dict[str, Any] | None = {}, + options: dict[str, Any] | None = {}, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ @@ -2293,13 +2282,13 @@ async def find_one( async def vector_find_one( self, - vector: List[float], + vector: list[float], *, - filter: Optional[Dict[str, Any]] = None, - fields: Optional[List[str]] = None, + filter: dict[str, Any] | None = None, + fields: list[str] | None = None, include_similarity: bool = True, timeout_info: TimeoutInfoWideType = None, - ) -> Union[API_DOC, None]: + ) -> API_DOC | None: """ Perform a vector-based search to find a single document in the collection. @@ -2365,8 +2354,8 @@ async def insert_one( async def insert_many( self, - documents: List[API_DOC], - options: Optional[Dict[str, Any]] = None, + documents: list[API_DOC], + options: dict[str, Any] | None = None, partial_failures_allowed: bool = False, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: @@ -2401,13 +2390,13 @@ async def insert_many( async def chunked_insert_many( self, - documents: List[API_DOC], - options: Optional[Dict[str, Any]] = None, + documents: list[API_DOC], + options: dict[str, Any] | None = None, partial_failures_allowed: bool = False, chunk_size: int = DEFAULT_INSERT_NUM_DOCUMENTS, concurrency: int = 1, timeout_info: TimeoutInfoWideType = None, - ) -> List[Union[API_RESPONSE, Exception]]: + ) -> list[API_RESPONSE | Exception]: """ Insert multiple documents into the collection, handling chunking and optionally with concurrent insertions. @@ -2435,10 +2424,10 @@ async def chunked_insert_many( sem = asyncio.Semaphore(concurrency) async def concurrent_insert_many( - docs: List[API_DOC], + docs: list[API_DOC], index: int, partial_failures_allowed: bool, - ) -> Union[API_RESPONSE, Exception]: + ) -> API_RESPONSE | Exception: async with sem: logger.debug(f"Processing chunk #{index + 1} of size {len(docs)}") try: @@ -2488,11 +2477,11 @@ async def concurrent_insert_many( async def update_one( self, - filter: Dict[str, Any], - update: Dict[str, Any], - sort: Optional[Dict[str, Any]] = None, + filter: dict[str, Any], + update: dict[str, Any], + sort: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, - options: Optional[Dict[str, Any]] = None, + options: dict[str, Any] | None = None, ) -> API_RESPONSE: """ Update a single document in the collection. @@ -2527,9 +2516,9 @@ async def update_one( async def update_many( self, - filter: Dict[str, Any], - update: Dict[str, Any], - options: Optional[Dict[str, Any]] = None, + filter: dict[str, Any], + update: dict[str, Any], + options: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ @@ -2582,7 +2571,7 @@ async def replace( async def delete_one( self, id: str, - sort: Optional[Dict[str, Any]] = None, + sort: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ @@ -2614,8 +2603,8 @@ async def delete_one( async def delete_one_by_predicate( self, - filter: Dict[str, Any], - sort: Optional[Dict[str, Any]] = None, + filter: dict[str, Any], + sort: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ @@ -2647,7 +2636,7 @@ async def delete_one_by_predicate( async def delete_many( self, - filter: Dict[str, Any], + filter: dict[str, Any], skip_error_check: bool = False, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: @@ -2682,8 +2671,8 @@ async def delete_many( return response async def chunked_delete_many( - self, filter: Dict[str, Any], timeout_info: TimeoutInfoWideType = None - ) -> List[API_RESPONSE]: + self, filter: dict[str, Any], timeout_info: TimeoutInfoWideType = None + ) -> list[API_RESPONSE]: """ Delete many documents from the collection based on a filter condition, chaining several API calls until exhaustion of the documents to delete. @@ -2818,7 +2807,7 @@ async def upsert_many( concurrency: int = 1, partial_failures_allowed: bool = False, timeout_info: TimeoutInfoWideType = None, - ) -> List[Union[str, Exception]]: + ) -> list[str | Exception]: """ Emulate an upsert operation for multiple documents in the collection. This method attempts to insert the documents. @@ -2860,13 +2849,13 @@ class AstraDB: def __init__( self, - token: Optional[str], + token: str | None, api_endpoint: str, - api_path: Optional[str] = None, - api_version: Optional[str] = None, - namespace: Optional[str] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + api_path: str | None = None, + api_version: str | None = None, + namespace: str | None = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> None: """ Initialize an Astra DB instance. @@ -2940,13 +2929,13 @@ def __eq__(self, other: Any) -> bool: def copy( self, *, - token: Optional[str] = None, - api_endpoint: Optional[str] = None, - api_path: Optional[str] = None, - api_version: Optional[str] = None, - namespace: Optional[str] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + token: str | None = None, + api_endpoint: str | None = None, + api_path: str | None = None, + api_version: str | None = None, + namespace: str | None = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> AstraDB: return AstraDB( token=token or self.token, @@ -2971,8 +2960,8 @@ def to_async(self) -> AsyncAstraDB: def set_caller( self, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> None: self.caller_name = caller_name self.caller_version = caller_version @@ -2980,9 +2969,9 @@ def set_caller( def _request( self, method: str = http_methods.POST, - path: Optional[str] = None, - json_data: Optional[Dict[str, Any]] = None, - url_params: Optional[Dict[str, Any]] = None, + path: str | None = None, + json_data: dict[str, Any] | None = None, + url_params: dict[str, Any] | None = None, skip_error_check: bool = False, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: @@ -3005,7 +2994,7 @@ def _request( return response def post_raw_request( - self, body: Dict[str, Any], timeout_info: TimeoutInfoWideType = None + self, body: dict[str, Any], timeout_info: TimeoutInfoWideType = None ) -> API_RESPONSE: return self._request( method=http_methods.POST, @@ -3028,7 +3017,7 @@ def collection(self, collection_name: str) -> AstraDBCollection: def get_collections( self, - options: Optional[Dict[str, Any]] = None, + options: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ @@ -3066,10 +3055,10 @@ def create_collection( self, collection_name: str, *, - options: Optional[Dict[str, Any]] = None, - dimension: Optional[int] = None, - metric: Optional[str] = None, - service_dict: Optional[Dict[str, str]] = None, + options: dict[str, Any] | None = None, + dimension: int | None = None, + metric: str | None = None, + service_dict: dict[str, str] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> AstraDBCollection: """ @@ -3169,13 +3158,13 @@ def delete_collection( class AsyncAstraDB: def __init__( self, - token: Optional[str], + token: str | None, api_endpoint: str, - api_path: Optional[str] = None, - api_version: Optional[str] = None, - namespace: Optional[str] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + api_path: str | None = None, + api_version: str | None = None, + namespace: str | None = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> None: """ Initialize an Astra DB instance. @@ -3252,22 +3241,22 @@ async def __aenter__(self) -> AsyncAstraDB: async def __aexit__( self, - exc_type: Optional[Type[BaseException]] = None, - exc_value: Optional[BaseException] = None, - traceback: Optional[TracebackType] = None, + exc_type: type[BaseException] | None = None, + exc_value: BaseException | None = None, + traceback: TracebackType | None = None, ) -> None: await self.client.aclose() def copy( self, *, - token: Optional[str] = None, - api_endpoint: Optional[str] = None, - api_path: Optional[str] = None, - api_version: Optional[str] = None, - namespace: Optional[str] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + token: str | None = None, + api_endpoint: str | None = None, + api_path: str | None = None, + api_version: str | None = None, + namespace: str | None = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> AsyncAstraDB: return AsyncAstraDB( token=token or self.token, @@ -3292,8 +3281,8 @@ def to_sync(self) -> AstraDB: def set_caller( self, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> None: self.caller_name = caller_name self.caller_version = caller_version @@ -3301,9 +3290,9 @@ def set_caller( async def _request( self, method: str = http_methods.POST, - path: Optional[str] = None, - json_data: Optional[Dict[str, Any]] = None, - url_params: Optional[Dict[str, Any]] = None, + path: str | None = None, + json_data: dict[str, Any] | None = None, + url_params: dict[str, Any] | None = None, skip_error_check: bool = False, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: @@ -3326,7 +3315,7 @@ async def _request( return response async def post_raw_request( - self, body: Dict[str, Any], timeout_info: TimeoutInfoWideType = None + self, body: dict[str, Any], timeout_info: TimeoutInfoWideType = None ) -> API_RESPONSE: return await self._request( method=http_methods.POST, @@ -3352,7 +3341,7 @@ async def collection(self, collection_name: str) -> AsyncAstraDBCollection: async def get_collections( self, - options: Optional[Dict[str, Any]] = None, + options: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ @@ -3390,10 +3379,10 @@ async def create_collection( self, collection_name: str, *, - options: Optional[Dict[str, Any]] = None, - dimension: Optional[int] = None, - metric: Optional[str] = None, - service_dict: Optional[Dict[str, str]] = None, + options: dict[str, Any] | None = None, + dimension: int | None = None, + metric: str | None = None, + service_dict: dict[str, str] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> AsyncAstraDBCollection: """ diff --git a/astrapy/core/ops.py b/astrapy/core/ops.py index 9cc90eb8..2b630d93 100644 --- a/astrapy/core/ops.py +++ b/astrapy/core/ops.py @@ -15,7 +15,7 @@ from __future__ import annotations import logging -from typing import Any, Dict, Optional, TypedDict, Union, cast +from typing import Any, TypedDict, cast import httpx @@ -38,11 +38,11 @@ class AstraDBOpsConstructorParams(TypedDict): - token: Union[str, None] - dev_ops_url: Optional[str] - dev_ops_api_version: Optional[str] - caller_name: Optional[str] - caller_version: Optional[str] + token: str | None + dev_ops_url: str | None + dev_ops_api_version: str | None + caller_name: str | None + caller_version: str | None class AstraDBOps: @@ -52,11 +52,11 @@ class AstraDBOps: def __init__( self, - token: Union[str, None], - dev_ops_url: Optional[str] = None, - dev_ops_api_version: Optional[str] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + token: str | None, + dev_ops_url: str | None = None, + dev_ops_api_version: str | None = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> None: self.caller_name = caller_name self.caller_version = caller_version @@ -74,7 +74,7 @@ def __init__( dev_ops_api_version or DEFAULT_DEV_OPS_API_VERSION ).strip("/") - self.token: Union[str, None] + self.token: str | None if token is not None: self.token = "Bearer " + token else: @@ -98,11 +98,11 @@ def __eq__(self, other: Any) -> bool: def copy( self, *, - token: Optional[str] = None, - dev_ops_url: Optional[str] = None, - dev_ops_api_version: Optional[str] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + token: str | None = None, + dev_ops_url: str | None = None, + dev_ops_api_version: str | None = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> AstraDBOps: return AstraDBOps( token=token or self.constructor_params["token"], @@ -115,8 +115,8 @@ def copy( def set_caller( self, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> None: self.caller_name = caller_name self.caller_version = caller_version @@ -125,8 +125,8 @@ def _ops_request( self, method: str, path: str, - options: Optional[Dict[str, Any]] = None, - json_data: Optional[Dict[str, Any]] = None, + options: dict[str, Any] | None = None, + json_data: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> httpx.Response: _options = {} if options is None else options @@ -151,8 +151,8 @@ async def _async_ops_request( self, method: str, path: str, - options: Optional[Dict[str, Any]] = None, - json_data: Optional[Dict[str, Any]] = None, + options: dict[str, Any] | None = None, + json_data: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> httpx.Response: _options = {} if options is None else options @@ -177,8 +177,8 @@ def _json_ops_request( self, method: str, path: str, - options: Optional[Dict[str, Any]] = None, - json_data: Optional[Dict[str, Any]] = None, + options: dict[str, Any] | None = None, + json_data: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> OPS_API_RESPONSE: _options = {} if options is None else options @@ -204,8 +204,8 @@ async def _async_json_ops_request( self, method: str, path: str, - options: Optional[Dict[str, Any]] = None, - json_data: Optional[Dict[str, Any]] = None, + options: dict[str, Any] | None = None, + json_data: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> OPS_API_RESPONSE: _options = {} if options is None else options @@ -229,7 +229,7 @@ async def _async_json_ops_request( def get_databases( self, - options: Optional[Dict[str, Any]] = None, + options: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> OPS_API_RESPONSE: """ @@ -252,7 +252,7 @@ def get_databases( async def async_get_databases( self, - options: Optional[Dict[str, Any]] = None, + options: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> OPS_API_RESPONSE: """ @@ -275,9 +275,9 @@ async def async_get_databases( def create_database( self, - database_definition: Optional[Dict[str, Any]] = None, + database_definition: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, - ) -> Dict[str, str]: + ) -> dict[str, str]: """ Create a new database. @@ -305,9 +305,9 @@ def create_database( async def async_create_database( self, - database_definition: Optional[Dict[str, Any]] = None, + database_definition: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, - ) -> Dict[str, str]: + ) -> dict[str, str]: """ Create a new database - async version of the method. @@ -392,7 +392,7 @@ async def async_terminate_database( def get_database( self, database: str = "", - options: Optional[Dict[str, Any]] = None, + options: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ @@ -418,7 +418,7 @@ def get_database( async def async_get_database( self, database: str = "", - options: Optional[Dict[str, Any]] = None, + options: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ @@ -446,7 +446,7 @@ def create_keyspace( database: str = "", keyspace: str = "", timeout_info: TimeoutInfoWideType = None, - ) -> Dict[str, str]: + ) -> dict[str, str]: """ Create a keyspace in a specified database. @@ -476,7 +476,7 @@ async def async_create_keyspace( database: str = "", keyspace: str = "", timeout_info: TimeoutInfoWideType = None, - ) -> Dict[str, str]: + ) -> dict[str, str]: """ Create a keyspace in a specified database - async version of the method. @@ -600,7 +600,7 @@ def unpark_database( def resize_database( self, database: str = "", - options: Optional[Dict[str, Any]] = None, + options: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> OPS_API_RESPONSE: """ @@ -623,7 +623,7 @@ def resize_database( def reset_database_password( self, database: str = "", - options: Optional[Dict[str, Any]] = None, + options: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> OPS_API_RESPONSE: """ @@ -682,7 +682,7 @@ def get_datacenters( def create_datacenter( self, database: str = "", - options: Optional[Dict[str, Any]] = None, + options: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> OPS_API_RESPONSE: """ @@ -745,7 +745,7 @@ def get_access_list( def replace_access_list( self, database: str = "", - access_list: Optional[Dict[str, Any]] = None, + access_list: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> OPS_API_RESPONSE: """ @@ -768,7 +768,7 @@ def replace_access_list( def update_access_list( self, database: str = "", - access_list: Optional[Dict[str, Any]] = None, + access_list: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> OPS_API_RESPONSE: """ @@ -791,7 +791,7 @@ def update_access_list( def add_access_list_address( self, database: str = "", - address: Optional[Dict[str, Any]] = None, + address: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> OPS_API_RESPONSE: """ @@ -873,7 +873,7 @@ def create_datacenter_private_link( self, database: str = "", datacenter: str = "", - private_link: Optional[Dict[str, Any]] = None, + private_link: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> OPS_API_RESPONSE: """ @@ -898,7 +898,7 @@ def create_datacenter_endpoint( self, database: str = "", datacenter: str = "", - endpoint: Optional[Dict[str, Any]] = None, + endpoint: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> OPS_API_RESPONSE: """ @@ -923,7 +923,7 @@ def update_datacenter_endpoint( self, database: str = "", datacenter: str = "", - endpoint: Dict[str, Any] = {}, + endpoint: dict[str, Any] = {}, timeout_info: TimeoutInfoWideType = None, ) -> OPS_API_RESPONSE: """ @@ -1035,7 +1035,7 @@ def get_roles(self, timeout_info: TimeoutInfoWideType = None) -> OPS_API_RESPONS def create_role( self, - role_definition: Optional[Dict[str, Any]] = None, + role_definition: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> OPS_API_RESPONSE: """ @@ -1075,7 +1075,7 @@ def get_role( def update_role( self, role: str = "", - role_definition: Optional[Dict[str, Any]] = None, + role_definition: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> OPS_API_RESPONSE: """ @@ -1115,7 +1115,7 @@ def delete_role( def invite_user( self, - user_definition: Optional[Dict[str, Any]] = None, + user_definition: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> OPS_API_RESPONSE: """ @@ -1186,7 +1186,7 @@ def remove_user( def update_user_roles( self, user: str = "", - roles: Optional[Dict[str, Any]] = None, + roles: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> OPS_API_RESPONSE: """ @@ -1219,7 +1219,7 @@ def get_clients(self, timeout_info: TimeoutInfoWideType = None) -> OPS_API_RESPO def create_token( self, - roles: Optional[Dict[str, Any]] = None, + roles: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> OPS_API_RESPONSE: """ @@ -1360,7 +1360,7 @@ def get_streaming_tenants( def create_streaming_tenant( self, - tenant: Optional[Dict[str, Any]] = None, + tenant: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> OPS_API_RESPONSE: """ diff --git a/astrapy/core/utils.py b/astrapy/core/utils.py index dad44019..6eece183 100644 --- a/astrapy/core/utils.py +++ b/astrapy/core/utils.py @@ -18,7 +18,7 @@ import json import logging import time -from typing import Any, Dict, Iterable, List, Optional, TypedDict, Union, cast +from typing import Any, Dict, Iterable, TypedDict, Union, cast import httpx @@ -57,7 +57,7 @@ class http_methods: user_agent_astrapy = f"{package_name}/{__version__}" -def detect_ragstack_user_agent() -> Optional[str]: +def detect_ragstack_user_agent() -> str | None: from importlib import metadata from importlib.metadata import PackageNotFoundError @@ -77,9 +77,9 @@ def detect_ragstack_user_agent() -> Optional[str]: def log_request( method: str, url: str, - params: Optional[Dict[str, Any]], - headers: Dict[str, str], - json_data: Optional[Dict[str, Any]], + params: dict[str, Any] | None, + headers: dict[str, str], + json_data: dict[str, Any] | None, ) -> None: """ Log the details of an HTTP request for debugging purposes. @@ -116,8 +116,8 @@ def log_response(r: httpx.Response) -> None: def user_agent_string( - caller_name: Optional[str], caller_version: Optional[str] -) -> Optional[str]: + caller_name: str | None, caller_version: str | None +) -> str | None: if caller_name: if caller_version: return f"{caller_name}/{caller_version}" @@ -127,9 +127,7 @@ def user_agent_string( return None -def compose_user_agent( - caller_name: Optional[str], caller_version: Optional[str] -) -> str: +def compose_user_agent(caller_name: str | None, caller_version: str | None) -> str: user_agent_caller = user_agent_string(caller_name, caller_version) all_user_agents = [ ua_block @@ -152,7 +150,7 @@ class TimeoutInfo(TypedDict, total=False): TimeoutInfoWideType = Union[TimeoutInfo, float, None] -def to_httpx_timeout(timeout_info: TimeoutInfoWideType) -> Union[httpx.Timeout, None]: +def to_httpx_timeout(timeout_info: TimeoutInfoWideType) -> httpx.Timeout | None: if timeout_info is None: return None if isinstance(timeout_info, float) or isinstance(timeout_info, int): @@ -170,15 +168,15 @@ def make_request( client: httpx.Client, base_url: str, auth_header: str, - token: Optional[str], + token: str | None, method: str, - json_data: Optional[Dict[str, Any]], - url_params: Optional[Dict[str, Any]], - path: Optional[str], - caller_name: Optional[str], - caller_version: Optional[str], - timeout: Optional[Union[httpx.Timeout, float]], - additional_headers: Dict[str, str], + json_data: dict[str, Any] | None, + url_params: dict[str, Any] | None, + path: str | None, + caller_name: str | None, + caller_version: str | None, + timeout: httpx.Timeout | float | None, + additional_headers: dict[str, str], ) -> httpx.Response: """ Make an HTTP request to a specified URL. @@ -233,15 +231,15 @@ async def amake_request( client: httpx.AsyncClient, base_url: str, auth_header: str, - token: Optional[str], + token: str | None, method: str, - path: Optional[str], - json_data: Optional[Dict[str, Any]], - url_params: Optional[Dict[str, Any]], - caller_name: Optional[str], - caller_version: Optional[str], - timeout: Optional[Union[httpx.Timeout, float]], - additional_headers: Dict[str, str], + path: str | None, + json_data: dict[str, Any] | None, + url_params: dict[str, Any] | None, + caller_name: str | None, + caller_version: str | None, + timeout: httpx.Timeout | float | None, + additional_headers: dict[str, str], ) -> httpx.Response: """ Make an HTTP request to a specified URL. @@ -292,7 +290,7 @@ async def amake_request( return r -def make_payload(top_level: str, **kwargs: Any) -> Dict[str, Any]: +def make_payload(top_level: str, **kwargs: Any) -> dict[str, Any]: """ Construct a JSON payload for an HTTP request with a specified top-level key. @@ -307,7 +305,7 @@ def make_payload(top_level: str, **kwargs: Any) -> Dict[str, Any]: for key, value in kwargs.items(): params[key] = value - json_query: Dict[str, Any] = {top_level: {}} + json_query: dict[str, Any] = {top_level: {}} # Adding keys only if they're provided for key, value in params.items(): @@ -317,7 +315,7 @@ def make_payload(top_level: str, **kwargs: Any) -> Dict[str, Any]: return json_query -def convert_vector_to_floats(vector: Iterable[Any]) -> List[float]: +def convert_vector_to_floats(vector: Iterable[Any]) -> list[float]: """ Convert a vector of strings to a vector of floats. @@ -341,36 +339,36 @@ def is_list_of_floats(vector: Iterable[Any]) -> bool: def convert_to_ejson_date_object( - date_value: Union[datetime.date, datetime.datetime] -) -> Dict[str, int]: + date_value: datetime.date | datetime.datetime, +) -> dict[str, int]: return {"$date": int(time.mktime(date_value.timetuple()) * 1000)} -def convert_to_ejson_uuid_object(uuid_value: UUID) -> Dict[str, str]: +def convert_to_ejson_uuid_object(uuid_value: UUID) -> dict[str, str]: return {"$uuid": str(uuid_value)} -def convert_to_ejson_objectid_object(objectid_value: ObjectId) -> Dict[str, str]: +def convert_to_ejson_objectid_object(objectid_value: ObjectId) -> dict[str, str]: return {"$objectId": str(objectid_value)} def convert_ejson_date_object_to_datetime( - date_object: Dict[str, int] + date_object: dict[str, int], ) -> datetime.datetime: return datetime.datetime.fromtimestamp(date_object["$date"] / 1000.0) -def convert_ejson_uuid_object_to_uuid(uuid_object: Dict[str, str]) -> UUID: +def convert_ejson_uuid_object_to_uuid(uuid_object: dict[str, str]) -> UUID: return UUID(uuid_object["$uuid"]) def convert_ejson_objectid_object_to_objectid( - objectid_object: Dict[str, str] + objectid_object: dict[str, str], ) -> ObjectId: return ObjectId(objectid_object["$objectId"]) -def _normalize_payload_value(path: List[str], value: Any) -> Any: +def _normalize_payload_value(path: list[str], value: Any) -> Any: """ The path helps determining special treatments """ @@ -401,9 +399,7 @@ def _normalize_payload_value(path: List[str], value: Any) -> Any: return value -def normalize_for_api( - payload: Union[Dict[str, Any], None] -) -> Union[Dict[str, Any], None]: +def normalize_for_api(payload: dict[str, Any] | None) -> dict[str, Any] | None: """ Normalize a payload for API calls. This includes e.g. ensuring values for "$vector" key @@ -422,7 +418,7 @@ def normalize_for_api( return payload -def _restore_response_value(path: List[str], value: Any) -> Any: +def _restore_response_value(path: list[str], value: Any) -> Any: """ The path helps determining special treatments """ diff --git a/astrapy/cursors.py b/astrapy/cursors.py index fab0b314..bf336b11 100644 --- a/astrapy/cursors.py +++ b/astrapy/cursors.py @@ -23,15 +23,12 @@ TYPE_CHECKING, Any, Callable, - Dict, Generic, Iterable, Iterator, - List, Optional, Tuple, TypeVar, - Union, ) from astrapy.constants import ( @@ -58,7 +55,7 @@ IndexPairType = Tuple[str, Optional[int]] -def _maybe_valid_list_index(key_block: str) -> Optional[int]: +def _maybe_valid_list_index(key_block: str) -> int | None: # '0', '1' is good. '00', '01', '-30' are not. try: kb_index = int(key_block) @@ -72,9 +69,8 @@ def _maybe_valid_list_index(key_block: str) -> Optional[int]: def _create_document_key_extractor( key: str, -) -> Callable[[Dict[str, Any]], Iterable[Any]]: - - key_blocks0: List[IndexPairType] = [ +) -> Callable[[dict[str, Any]], Iterable[Any]]: + key_blocks0: list[IndexPairType] = [ (kb_str, _maybe_valid_list_index(kb_str)) for kb_str in key.split(".") ] if key_blocks0 == []: @@ -83,7 +79,7 @@ def _create_document_key_extractor( raise ValueError("Field path components cannot be empty") def _extract_with_key_blocks( - key_blocks: List[IndexPairType], value: Any + key_blocks: list[IndexPairType], value: Any ) -> Iterable[Any]: if key_blocks == []: if isinstance(value, list): @@ -123,7 +119,7 @@ def _extract_with_key_blocks( # keyblocks are deeper than the document. Nothing to extract. return - def _item_extractor(document: Dict[str, Any]) -> Iterable[Any]: + def _item_extractor(document: dict[str, Any]) -> Iterable[Any]: return _extract_with_key_blocks(key_blocks=key_blocks0, value=document) return _item_extractor @@ -148,7 +144,7 @@ def _reduce_distinct_key_to_safe(distinct_key: str) -> str: return ".".join(valid_portion) -def _hash_document(document: Dict[str, Any]) -> str: +def _hash_document(document: dict[str, Any]) -> str: _normalized_item = normalize_payload_value(path=[], value=document) _normalized_json = json.dumps( _normalized_item, sort_keys=True, separators=(",", ":") @@ -165,7 +161,7 @@ class _LookAheadIterator: def __init__(self, iterator: Iterator[DocumentType]): self.iterator = iterator - self.preread_item: Optional[DocumentType] = None + self.preread_item: DocumentType | None = None self.has_preread = False self.preread_exhausted = False @@ -201,7 +197,7 @@ class _AsyncLookAheadIterator: def __init__(self, async_iterator: AsyncIterator[DocumentType]): self.async_iterator = async_iterator - self.preread_item: Optional[DocumentType] = None + self.preread_item: DocumentType | None = None self.has_preread = False self.preread_exhausted = False @@ -237,30 +233,30 @@ class BaseCursor: See classes Cursor and AsyncCursor for more information. """ - _collection: Union[Collection, AsyncCollection] - _filter: Optional[Dict[str, Any]] - _projection: Optional[ProjectionType] - _max_time_ms: Optional[int] - _overall_max_time_ms: Optional[int] - _started_time_s: Optional[float] - _limit: Optional[int] - _skip: Optional[int] - _include_similarity: Optional[bool] - _include_sort_vector: Optional[bool] - _sort: Optional[Dict[str, Any]] + _collection: Collection | AsyncCollection + _filter: dict[str, Any] | None + _projection: ProjectionType | None + _max_time_ms: int | None + _overall_max_time_ms: int | None + _started_time_s: float | None + _limit: int | None + _skip: int | None + _include_similarity: bool | None + _include_sort_vector: bool | None + _sort: dict[str, Any] | None _started: bool _retrieved: int _alive: bool - _iterator: Optional[Union[_LookAheadIterator, _AsyncLookAheadIterator]] = None - _api_response_status: Optional[Dict[str, Any]] + _iterator: _LookAheadIterator | _AsyncLookAheadIterator | None = None + _api_response_status: dict[str, Any] | None def __init__( self, - collection: Union[Collection, AsyncCollection], - filter: Optional[Dict[str, Any]], - projection: Optional[ProjectionType], - max_time_ms: Optional[int], - overall_max_time_ms: Optional[int], + collection: Collection | AsyncCollection, + filter: dict[str, Any] | None, + projection: ProjectionType | None, + max_time_ms: int | None, + overall_max_time_ms: int | None, ) -> None: raise NotImplementedError @@ -315,15 +311,15 @@ def _ensure_not_started(self) -> None: def _copy( self: BC, *, - projection: Optional[ProjectionType] = None, - max_time_ms: Optional[int] = None, - overall_max_time_ms: Optional[int] = None, - limit: Optional[int] = None, - skip: Optional[int] = None, - include_similarity: Optional[bool] = None, - include_sort_vector: Optional[bool] = None, - started: Optional[bool] = None, - sort: Optional[Dict[str, Any]] = None, + projection: ProjectionType | None = None, + max_time_ms: int | None = None, + overall_max_time_ms: int | None = None, + limit: int | None = None, + skip: int | None = None, + include_similarity: bool | None = None, + include_sort_vector: bool | None = None, + started: bool | None = None, + sort: dict[str, Any] | None = None, ) -> BC: new_cursor = self.__class__( collection=self._collection, @@ -417,7 +413,7 @@ def cursor_id(self) -> int: return id(self) - def limit(self: BC, limit: Optional[int]) -> BC: + def limit(self: BC, limit: int | None) -> BC: """ Set a new `limit` value for this cursor. @@ -433,7 +429,7 @@ def limit(self: BC, limit: Optional[int]) -> BC: self._limit = limit if limit != 0 else None return self - def include_similarity(self: BC, include_similarity: Optional[bool]) -> BC: + def include_similarity(self: BC, include_similarity: bool | None) -> BC: """ Set a new `include_similarity` value for this cursor. @@ -449,7 +445,7 @@ def include_similarity(self: BC, include_similarity: Optional[bool]) -> BC: self._include_similarity = include_similarity return self - def include_sort_vector(self: BC, include_sort_vector: Optional[bool]) -> BC: + def include_sort_vector(self: BC, include_sort_vector: bool | None) -> BC: """ Set a new `include_sort_vector` value for this cursor. @@ -487,7 +483,7 @@ def rewind(self: BC) -> BC: self._iterator = None return self - def skip(self: BC, skip: Optional[int]) -> BC: + def skip(self: BC, skip: int | None) -> BC: """ Set a new `skip` value for this cursor. @@ -509,7 +505,7 @@ def skip(self: BC, skip: Optional[int]) -> BC: def sort( self: BC, - sort: Optional[Dict[str, Any]], + sort: dict[str, Any] | None, ) -> BC: """ Set a new `sort` value for this cursor. @@ -581,10 +577,10 @@ class Cursor(BaseCursor): def __init__( self, collection: Collection, - filter: Optional[Dict[str, Any]], - projection: Optional[ProjectionType], - max_time_ms: Optional[int], - overall_max_time_ms: Optional[int], + filter: dict[str, Any] | None, + projection: ProjectionType | None, + max_time_ms: int | None, + overall_max_time_ms: int | None, ) -> None: self._collection: Collection = collection self._filter = filter @@ -594,17 +590,17 @@ def __init__( self._max_time_ms = min(max_time_ms, overall_max_time_ms) else: self._max_time_ms = max_time_ms - self._limit: Optional[int] = None - self._skip: Optional[int] = None - self._include_similarity: Optional[bool] = None - self._include_sort_vector: Optional[bool] = None - self._sort: Optional[Dict[str, Any]] = None + self._limit: int | None = None + self._skip: int | None = None + self._include_similarity: bool | None = None + self._include_sort_vector: bool | None = None + self._sort: dict[str, Any] | None = None self._started = False self._retrieved = 0 self._alive = True # - self._iterator: Optional[_LookAheadIterator] = None - self._api_response_status: Optional[Dict[str, Any]] = None + self._iterator: _LookAheadIterator | None = None + self._api_response_status: dict[str, Any] | None = None def __iter__(self) -> Cursor: self._ensure_alive() @@ -638,7 +634,7 @@ def __next__(self) -> DocumentType: self._alive = False raise - def get_sort_vector(self) -> Optional[List[float]]: + def get_sort_vector(self) -> list[float] | None: """ Return the vector used in this ANN search, if applicable. If this is not an ANN search, or it was invoked without the @@ -695,7 +691,7 @@ def _create_iterator(self) -> _LookAheadIterator: } def _find_iterator() -> Iterator[DocumentType]: - next_page_state: Optional[str] = None + next_page_state: str | None = None # resp_0 = self._collection.command( body=f0_payload, @@ -762,7 +758,7 @@ def collection(self) -> Collection: return self._collection - def distinct(self, key: str, max_time_ms: Optional[int] = None) -> List[Any]: + def distinct(self, key: str, max_time_ms: int | None = None) -> list[Any]: """ Compute a list of unique values for a specific field across all documents the cursor iterates through. @@ -860,10 +856,10 @@ class AsyncCursor(BaseCursor): def __init__( self, collection: AsyncCollection, - filter: Optional[Dict[str, Any]], - projection: Optional[ProjectionType], - max_time_ms: Optional[int], - overall_max_time_ms: Optional[int], + filter: dict[str, Any] | None, + projection: ProjectionType | None, + max_time_ms: int | None, + overall_max_time_ms: int | None, ) -> None: self._collection: AsyncCollection = collection self._filter = filter @@ -873,17 +869,17 @@ def __init__( self._max_time_ms = min(max_time_ms, overall_max_time_ms) else: self._max_time_ms = max_time_ms - self._limit: Optional[int] = None - self._skip: Optional[int] = None - self._include_similarity: Optional[bool] = None - self._include_sort_vector: Optional[bool] = None - self._sort: Optional[Dict[str, Any]] = None + self._limit: int | None = None + self._skip: int | None = None + self._include_similarity: bool | None = None + self._include_sort_vector: bool | None = None + self._sort: dict[str, Any] | None = None self._started = False self._retrieved = 0 self._alive = True # - self._iterator: Optional[_AsyncLookAheadIterator] = None - self._api_response_status: Optional[Dict[str, Any]] = None + self._iterator: _AsyncLookAheadIterator | None = None + self._api_response_status: dict[str, Any] | None = None def __aiter__(self) -> AsyncCursor: self._ensure_alive() @@ -917,7 +913,7 @@ async def __anext__(self) -> DocumentType: self._alive = False raise - async def get_sort_vector(self) -> Optional[List[float]]: + async def get_sort_vector(self) -> list[float] | None: """ Return the vector used in this ANN search, if applicable. If this is not an ANN search, or it was invoked without the @@ -1034,12 +1030,12 @@ async def _find_iterator() -> AsyncIterator[DocumentType]: def _to_sync( self: AsyncCursor, *, - limit: Optional[int] = None, - skip: Optional[int] = None, - include_similarity: Optional[bool] = None, - include_sort_vector: Optional[bool] = None, - started: Optional[bool] = None, - sort: Optional[Dict[str, Any]] = None, + limit: int | None = None, + skip: int | None = None, + include_similarity: bool | None = None, + include_sort_vector: bool | None = None, + started: bool | None = None, + sort: dict[str, Any] | None = None, ) -> Cursor: new_cursor = Cursor( collection=self._collection.to_sync(), @@ -1079,7 +1075,7 @@ def collection(self) -> AsyncCollection: return self._collection - async def distinct(self, key: str, max_time_ms: Optional[int] = None) -> List[Any]: + async def distinct(self, key: str, max_time_ms: int | None = None) -> list[Any]: """ Compute a list of unique values for a specific field across all documents the cursor iterates through. @@ -1142,7 +1138,7 @@ class CommandCursor(Generic[T]): (such as the database `list_collections` method). """ - def __init__(self, address: str, items: List[T]) -> None: + def __init__(self, address: str, items: list[T]) -> None: self._address = address self.items = items self.iterable = items.__iter__() @@ -1225,7 +1221,7 @@ class AsyncCommandCursor(Generic[T]): (such as the database `list_collections` method). """ - def __init__(self, address: str, items: List[T]) -> None: + def __init__(self, address: str, items: list[T]) -> None: self._address = address self.items = items self.iterable = items.__iter__() diff --git a/astrapy/database.py b/astrapy/database.py index 2b6fc5d4..0f5705af 100644 --- a/astrapy/database.py +++ b/astrapy/database.py @@ -17,7 +17,7 @@ import logging import warnings from types import TracebackType -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union +from typing import TYPE_CHECKING, Any import deprecation @@ -63,13 +63,13 @@ def _normalize_create_collection_options( - dimension: Optional[int], - metric: Optional[str], - service: Optional[Union[CollectionVectorServiceOptions, Dict[str, Any]]], - indexing: Optional[Dict[str, Any]], - default_id_type: Optional[str], - additional_options: Optional[Dict[str, Any]], -) -> Dict[str, Any]: + dimension: int | None, + metric: str | None, + service: CollectionVectorServiceOptions | dict[str, Any] | None, + indexing: dict[str, Any] | None, + default_id_type: str | None, + additional_options: dict[str, Any] | None, +) -> dict[str, Any]: """Raise errors related to invalid input, and return a ready-to-send payload.""" is_vector: bool if service is not None or dimension is not None: @@ -82,7 +82,7 @@ def _normalize_create_collection_options( "create_collection method." ) # prepare the payload - service_dict: Optional[Dict[str, Any]] + service_dict: dict[str, Any] | None if service is not None: service_dict = service if isinstance(service, dict) else service.as_dict() else: @@ -172,15 +172,15 @@ class Database: def __init__( self, api_endpoint: str, - token: Optional[Union[str, TokenProvider]] = None, + token: str | TokenProvider | None = None, *, - keyspace: Optional[str] = None, - namespace: Optional[str] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, - environment: Optional[str] = None, - api_path: Optional[str] = None, - api_version: Optional[str] = None, + keyspace: str | None = None, + namespace: str | None = None, + caller_name: str | None = None, + caller_version: str | None = None, + environment: str | None = None, + api_path: str | None = None, + api_version: str | None = None, ) -> None: keyspace_param = check_namespace_keyspace( keyspace=keyspace, @@ -188,8 +188,8 @@ def __init__( ) self.environment = (environment or Environment.PROD).lower() # - _api_path: Optional[str] - _api_version: Optional[str] + _api_path: str | None + _api_version: str | None if api_path is None: _api_path = API_PATH_ENV_MAP[self.environment] else: @@ -204,7 +204,7 @@ def __init__( self.api_version = _api_version # enforce defaults if on Astra DB: - self._using_keyspace: Optional[str] + self._using_keyspace: str | None if keyspace_param is None and self.environment in Environment.astra_db_values: self._using_keyspace = DEFAULT_ASTRA_DB_KEYSPACE else: @@ -217,7 +217,7 @@ def __init__( self.caller_name = caller_name self.caller_version = caller_version self._api_commander = self._get_api_commander(keyspace=self.keyspace) - self._name: Optional[str] = None + self._name: str | None = None def __getattr__(self, collection_name: str) -> Collection: return self.get_collection(name=collection_name) @@ -227,12 +227,12 @@ def __getitem__(self, collection_name: str) -> Collection: def __repr__(self) -> str: ep_desc = f'api_endpoint="{self.api_endpoint}"' - token_desc: Optional[str] + token_desc: str | None if self.token_provider: token_desc = f'token="{redact_secret(str(self.token_provider), 15)}"' else: token_desc = None - keyspace_desc: Optional[str] + keyspace_desc: str | None if self.keyspace is None: keyspace_desc = "keyspace not set" else: @@ -257,7 +257,7 @@ def __eq__(self, other: Any) -> bool: else: return False - def _get_api_commander(self, keyspace: Optional[str]) -> Optional[APICommander]: + def _get_api_commander(self, keyspace: str | None) -> APICommander | None: """ Instantiate a new APICommander based on the properties of this class and a provided keyspace. @@ -286,12 +286,12 @@ def _get_api_commander(self, keyspace: Optional[str]) -> Optional[APICommander]: ) return api_commander - def _get_driver_commander(self, keyspace: Optional[str]) -> APICommander: + def _get_driver_commander(self, keyspace: str | None) -> APICommander: """ Building on _get_api_commander, fall back to class keyspace in creating/returning a commander, and in any case raise an error if not set. """ - driver_commander: Optional[APICommander] + driver_commander: APICommander | None if keyspace: driver_commander = self._get_api_commander(keyspace=keyspace) else: @@ -306,15 +306,15 @@ def _get_driver_commander(self, keyspace: Optional[str]) -> APICommander: def _copy( self, *, - api_endpoint: Optional[str] = None, - token: Optional[Union[str, TokenProvider]] = None, - keyspace: Optional[str] = None, - namespace: Optional[str] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, - environment: Optional[str] = None, - api_path: Optional[str] = None, - api_version: Optional[str] = None, + api_endpoint: str | None = None, + token: str | TokenProvider | None = None, + keyspace: str | None = None, + namespace: str | None = None, + caller_name: str | None = None, + caller_version: str | None = None, + environment: str | None = None, + api_path: str | None = None, + api_version: str | None = None, ) -> Database: keyspace_param = check_namespace_keyspace( keyspace=keyspace, @@ -334,10 +334,10 @@ def _copy( def with_options( self, *, - keyspace: Optional[str] = None, - namespace: Optional[str] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + keyspace: str | None = None, + namespace: str | None = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> Database: """ Create a clone of this database with some changed attributes. @@ -375,15 +375,15 @@ def with_options( def to_async( self, *, - api_endpoint: Optional[str] = None, - token: Optional[Union[str, TokenProvider]] = None, - keyspace: Optional[str] = None, - namespace: Optional[str] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, - environment: Optional[str] = None, - api_path: Optional[str] = None, - api_version: Optional[str] = None, + api_endpoint: str | None = None, + token: str | TokenProvider | None = None, + keyspace: str | None = None, + namespace: str | None = None, + caller_name: str | None = None, + caller_version: str | None = None, + environment: str | None = None, + api_path: str | None = None, + api_version: str | None = None, ) -> AsyncDatabase: """ Create an AsyncDatabase from this one. Save for the arguments @@ -436,8 +436,8 @@ def to_async( def set_caller( self, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> None: """ Set a new identity for the application/framework on behalf of which @@ -583,7 +583,7 @@ def name(self) -> str: return self._name @property - def namespace(self) -> Optional[str]: + def namespace(self) -> str | None: """ The namespace this database uses as target for all commands when no method-call-specific namespace is specified. @@ -609,7 +609,7 @@ def namespace(self) -> Optional[str]: return self.keyspace @property - def keyspace(self) -> Optional[str]: + def keyspace(self) -> str | None: """ The keyspace this database uses as target for all commands when no method-call-specific keyspace is specified. @@ -628,10 +628,10 @@ def get_collection( self, name: str, *, - keyspace: Optional[str] = None, - namespace: Optional[str] = None, - embedding_api_key: Optional[Union[str, EmbeddingHeadersProvider]] = None, - collection_max_time_ms: Optional[int] = None, + keyspace: str | None = None, + namespace: str | None = None, + embedding_api_key: str | EmbeddingHeadersProvider | None = None, + collection_max_time_ms: int | None = None, ) -> Collection: """ Spawn a `Collection` object instance representing a collection @@ -711,18 +711,18 @@ def create_collection( self, name: str, *, - keyspace: Optional[str] = None, - namespace: Optional[str] = None, - dimension: Optional[int] = None, - metric: Optional[str] = None, - service: Optional[Union[CollectionVectorServiceOptions, Dict[str, Any]]] = None, - indexing: Optional[Dict[str, Any]] = None, - default_id_type: Optional[str] = None, - additional_options: Optional[Dict[str, Any]] = None, - check_exists: Optional[bool] = None, - max_time_ms: Optional[int] = None, - embedding_api_key: Optional[Union[str, EmbeddingHeadersProvider]] = None, - collection_max_time_ms: Optional[int] = None, + keyspace: str | None = None, + namespace: str | None = None, + dimension: int | None = None, + metric: str | None = None, + service: CollectionVectorServiceOptions | dict[str, Any] | None = None, + indexing: dict[str, Any] | None = None, + default_id_type: str | None = None, + additional_options: dict[str, Any] | None = None, + check_exists: bool | None = None, + max_time_ms: int | None = None, + embedding_api_key: str | EmbeddingHeadersProvider | None = None, + collection_max_time_ms: int | None = None, ) -> Collection: """ Creates a collection on the database and return the Collection @@ -852,10 +852,10 @@ def create_collection( def drop_collection( self, - name_or_collection: Union[str, Collection], + name_or_collection: str | Collection, *, - max_time_ms: Optional[int] = None, - ) -> Dict[str, Any]: + max_time_ms: int | None = None, + ) -> dict[str, Any]: """ Drop a collection from the database, along with all documents therein. @@ -883,7 +883,7 @@ def drop_collection( # lazy importing here against circular-import error from astrapy.collection import Collection - _keyspace: Optional[str] + _keyspace: str | None _collection_name: str if isinstance(name_or_collection, Collection): _keyspace = name_or_collection.keyspace @@ -904,9 +904,9 @@ def drop_collection( def list_collections( self, *, - keyspace: Optional[str] = None, - namespace: Optional[str] = None, - max_time_ms: Optional[int] = None, + keyspace: str | None = None, + namespace: str | None = None, + max_time_ms: int | None = None, ) -> CommandCursor[CollectionDescriptor]: """ List all collections in a given keyspace for this database. @@ -964,10 +964,10 @@ def list_collections( def list_collection_names( self, *, - keyspace: Optional[str] = None, - namespace: Optional[str] = None, - max_time_ms: Optional[int] = None, - ) -> List[str]: + keyspace: str | None = None, + namespace: str | None = None, + max_time_ms: int | None = None, + ) -> list[str]: """ List the names of all collections in a given keyspace of this database. @@ -991,7 +991,7 @@ def list_collection_names( ) driver_commander = self._get_driver_commander(keyspace=keyspace_param) - gc_payload: Dict[str, Any] = {"findCollections": {}} + gc_payload: dict[str, Any] = {"findCollections": {}} logger.info("findCollections") gc_response = driver_commander.request( payload=gc_payload, @@ -1009,14 +1009,14 @@ def list_collection_names( def command( self, - body: Dict[str, Any], + body: dict[str, Any], *, - keyspace: Optional[str] = None, - namespace: Optional[str] = None, - collection_name: Optional[str] = None, + keyspace: str | None = None, + namespace: str | None = None, + collection_name: str | None = None, raise_api_errors: bool = True, - max_time_ms: Optional[int] = None, - ) -> Dict[str, Any]: + max_time_ms: int | None = None, + ) -> dict[str, Any]: """ Send a POST request to the Data API for this database with an arbitrary, caller-provided payload. @@ -1080,9 +1080,9 @@ def command( def get_database_admin( self, *, - token: Optional[Union[str, TokenProvider]] = None, - dev_ops_url: Optional[str] = None, - dev_ops_api_version: Optional[str] = None, + token: str | TokenProvider | None = None, + dev_ops_url: str | None = None, + dev_ops_api_version: str | None = None, ) -> DatabaseAdmin: """ Return a DatabaseAdmin object corresponding to this database, for @@ -1207,15 +1207,15 @@ class AsyncDatabase: def __init__( self, api_endpoint: str, - token: Optional[Union[str, TokenProvider]] = None, + token: str | TokenProvider | None = None, *, - keyspace: Optional[str] = None, - namespace: Optional[str] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, - environment: Optional[str] = None, - api_path: Optional[str] = None, - api_version: Optional[str] = None, + keyspace: str | None = None, + namespace: str | None = None, + caller_name: str | None = None, + caller_version: str | None = None, + environment: str | None = None, + api_path: str | None = None, + api_version: str | None = None, ) -> None: keyspace_param = check_namespace_keyspace( keyspace=keyspace, @@ -1223,8 +1223,8 @@ def __init__( ) self.environment = (environment or Environment.PROD).lower() # - _api_path: Optional[str] - _api_version: Optional[str] + _api_path: str | None + _api_version: str | None if api_path is None: _api_path = API_PATH_ENV_MAP[self.environment] else: @@ -1239,7 +1239,7 @@ def __init__( self.api_version = _api_version # enforce defaults if on Astra DB: - self._using_keyspace: Optional[str] + self._using_keyspace: str | None if keyspace_param is None and self.environment in Environment.astra_db_values: self._using_keyspace = DEFAULT_ASTRA_DB_KEYSPACE else: @@ -1252,7 +1252,7 @@ def __init__( self.caller_name = caller_name self.caller_version = caller_version self._api_commander = self._get_api_commander(keyspace=self.keyspace) - self._name: Optional[str] = None + self._name: str | None = None def __getattr__(self, collection_name: str) -> AsyncCollection: return self.to_sync().get_collection(name=collection_name).to_async() @@ -1262,12 +1262,12 @@ def __getitem__(self, collection_name: str) -> AsyncCollection: def __repr__(self) -> str: ep_desc = f'api_endpoint="{self.api_endpoint}"' - token_desc: Optional[str] + token_desc: str | None if self.token_provider: token_desc = f'token="{redact_secret(str(self.token_provider), 15)}"' else: token_desc = None - keyspace_desc: Optional[str] + keyspace_desc: str | None if self.keyspace is None: keyspace_desc = "keyspace not set" else: @@ -1292,7 +1292,7 @@ def __eq__(self, other: Any) -> bool: else: return False - def _get_api_commander(self, keyspace: Optional[str]) -> Optional[APICommander]: + def _get_api_commander(self, keyspace: str | None) -> APICommander | None: """ Instantiate a new APICommander based on the properties of this class and a provided keyspace. @@ -1321,12 +1321,12 @@ def _get_api_commander(self, keyspace: Optional[str]) -> Optional[APICommander]: ) return api_commander - def _get_driver_commander(self, keyspace: Optional[str]) -> APICommander: + def _get_driver_commander(self, keyspace: str | None) -> APICommander: """ Building on _get_api_commander, fall back to class keyspace in creating/returning a commander, and in any case raise an error if not set. """ - driver_commander: Optional[APICommander] + driver_commander: APICommander | None if keyspace: driver_commander = self._get_api_commander(keyspace=keyspace) else: @@ -1343,9 +1343,9 @@ async def __aenter__(self) -> AsyncDatabase: async def __aexit__( self, - exc_type: Optional[Type[BaseException]] = None, - exc_value: Optional[BaseException] = None, - traceback: Optional[TracebackType] = None, + exc_type: type[BaseException] | None = None, + exc_value: BaseException | None = None, + traceback: TracebackType | None = None, ) -> None: if self._api_commander is not None: await self._api_commander.__aexit__( @@ -1357,15 +1357,15 @@ async def __aexit__( def _copy( self, *, - api_endpoint: Optional[str] = None, - token: Optional[Union[str, TokenProvider]] = None, - keyspace: Optional[str] = None, - namespace: Optional[str] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, - environment: Optional[str] = None, - api_path: Optional[str] = None, - api_version: Optional[str] = None, + api_endpoint: str | None = None, + token: str | TokenProvider | None = None, + keyspace: str | None = None, + namespace: str | None = None, + caller_name: str | None = None, + caller_version: str | None = None, + environment: str | None = None, + api_path: str | None = None, + api_version: str | None = None, ) -> AsyncDatabase: keyspace_param = check_namespace_keyspace( keyspace=keyspace, @@ -1385,10 +1385,10 @@ def _copy( def with_options( self, *, - keyspace: Optional[str] = None, - namespace: Optional[str] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + keyspace: str | None = None, + namespace: str | None = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> AsyncDatabase: """ Create a clone of this database with some changed attributes. @@ -1427,15 +1427,15 @@ def with_options( def to_sync( self, *, - api_endpoint: Optional[str] = None, - token: Optional[Union[str, TokenProvider]] = None, - keyspace: Optional[str] = None, - namespace: Optional[str] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, - environment: Optional[str] = None, - api_path: Optional[str] = None, - api_version: Optional[str] = None, + api_endpoint: str | None = None, + token: str | TokenProvider | None = None, + keyspace: str | None = None, + namespace: str | None = None, + caller_name: str | None = None, + caller_version: str | None = None, + environment: str | None = None, + api_path: str | None = None, + api_version: str | None = None, ) -> Database: """ Create a (synchronous) Database from this one. Save for the arguments @@ -1489,8 +1489,8 @@ def to_sync( def set_caller( self, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> None: """ Set a new identity for the application/framework on behalf of which @@ -1636,7 +1636,7 @@ def name(self) -> str: return self._name @property - def namespace(self) -> Optional[str]: + def namespace(self) -> str | None: """ The namespace this database uses as target for all commands when no method-call-specific namespace is specified. @@ -1662,7 +1662,7 @@ def namespace(self) -> Optional[str]: return self.keyspace @property - def keyspace(self) -> Optional[str]: + def keyspace(self) -> str | None: """ The keyspace this database uses as target for all commands when no method-call-specific keyspace is specified. @@ -1681,10 +1681,10 @@ async def get_collection( self, name: str, *, - keyspace: Optional[str] = None, - namespace: Optional[str] = None, - embedding_api_key: Optional[Union[str, EmbeddingHeadersProvider]] = None, - collection_max_time_ms: Optional[int] = None, + keyspace: str | None = None, + namespace: str | None = None, + embedding_api_key: str | EmbeddingHeadersProvider | None = None, + collection_max_time_ms: int | None = None, ) -> AsyncCollection: """ Spawn an `AsyncCollection` object instance representing a collection @@ -1767,18 +1767,18 @@ async def create_collection( self, name: str, *, - keyspace: Optional[str] = None, - namespace: Optional[str] = None, - dimension: Optional[int] = None, - metric: Optional[str] = None, - service: Optional[Union[CollectionVectorServiceOptions, Dict[str, Any]]] = None, - indexing: Optional[Dict[str, Any]] = None, - default_id_type: Optional[str] = None, - additional_options: Optional[Dict[str, Any]] = None, - check_exists: Optional[bool] = None, - max_time_ms: Optional[int] = None, - embedding_api_key: Optional[Union[str, EmbeddingHeadersProvider]] = None, - collection_max_time_ms: Optional[int] = None, + keyspace: str | None = None, + namespace: str | None = None, + dimension: int | None = None, + metric: str | None = None, + service: CollectionVectorServiceOptions | dict[str, Any] | None = None, + indexing: dict[str, Any] | None = None, + default_id_type: str | None = None, + additional_options: dict[str, Any] | None = None, + check_exists: bool | None = None, + max_time_ms: int | None = None, + embedding_api_key: str | EmbeddingHeadersProvider | None = None, + collection_max_time_ms: int | None = None, ) -> AsyncCollection: """ Creates a collection on the database and return the AsyncCollection @@ -1911,10 +1911,10 @@ async def create_collection( async def drop_collection( self, - name_or_collection: Union[str, AsyncCollection], + name_or_collection: str | AsyncCollection, *, - max_time_ms: Optional[int] = None, - ) -> Dict[str, Any]: + max_time_ms: int | None = None, + ) -> dict[str, Any]: """ Drop a collection from the database, along with all documents therein. @@ -1942,7 +1942,7 @@ async def drop_collection( # lazy importing here against circular-import error from astrapy.collection import AsyncCollection - keyspace: Optional[str] + keyspace: str | None _collection_name: str if isinstance(name_or_collection, AsyncCollection): keyspace = name_or_collection.keyspace @@ -1963,9 +1963,9 @@ async def drop_collection( def list_collections( self, *, - keyspace: Optional[str] = None, - namespace: Optional[str] = None, - max_time_ms: Optional[int] = None, + keyspace: str | None = None, + namespace: str | None = None, + max_time_ms: int | None = None, ) -> AsyncCommandCursor[CollectionDescriptor]: """ List all collections in a given keyspace for this database. @@ -2025,10 +2025,10 @@ def list_collections( async def list_collection_names( self, *, - keyspace: Optional[str] = None, - namespace: Optional[str] = None, - max_time_ms: Optional[int] = None, - ) -> List[str]: + keyspace: str | None = None, + namespace: str | None = None, + max_time_ms: int | None = None, + ) -> list[str]: """ List the names of all collections in a given keyspace of this database. @@ -2052,7 +2052,7 @@ async def list_collection_names( ) driver_commander = self._get_driver_commander(keyspace=keyspace_param) - gc_payload: Dict[str, Any] = {"findCollections": {}} + gc_payload: dict[str, Any] = {"findCollections": {}} logger.info("findCollections") gc_response = await driver_commander.async_request( payload=gc_payload, @@ -2070,14 +2070,14 @@ async def list_collection_names( async def command( self, - body: Dict[str, Any], + body: dict[str, Any], *, - keyspace: Optional[str] = None, - namespace: Optional[str] = None, - collection_name: Optional[str] = None, + keyspace: str | None = None, + namespace: str | None = None, + collection_name: str | None = None, raise_api_errors: bool = True, - max_time_ms: Optional[int] = None, - ) -> Dict[str, Any]: + max_time_ms: int | None = None, + ) -> dict[str, Any]: """ Send a POST request to the Data API for this database with an arbitrary, caller-provided payload. @@ -2145,9 +2145,9 @@ async def command( def get_database_admin( self, *, - token: Optional[Union[str, TokenProvider]] = None, - dev_ops_url: Optional[str] = None, - dev_ops_api_version: Optional[str] = None, + token: str | TokenProvider | None = None, + dev_ops_url: str | None = None, + dev_ops_api_version: str | None = None, ) -> DatabaseAdmin: """ Return a DatabaseAdmin object corresponding to this database, for diff --git a/astrapy/exceptions.py b/astrapy/exceptions.py index 80a0f451..e24f3a71 100644 --- a/astrapy/exceptions.py +++ b/astrapy/exceptions.py @@ -16,7 +16,7 @@ import time from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any import httpx @@ -55,15 +55,15 @@ class DevOpsAPIHttpException(DevOpsAPIException, httpx.HTTPStatusError): found in the response. """ - text: Optional[str] - error_descriptors: List[DevOpsAPIErrorDescriptor] + text: str | None + error_descriptors: list[DevOpsAPIErrorDescriptor] def __init__( self, - text: Optional[str], + text: str | None, *, httpx_error: httpx.HTTPStatusError, - error_descriptors: List[DevOpsAPIErrorDescriptor], + error_descriptors: list[DevOpsAPIErrorDescriptor], ) -> None: DataAPIException.__init__(self, text) httpx.HTTPStatusError.__init__( @@ -87,7 +87,7 @@ def from_httpx_error( ) -> DevOpsAPIHttpException: """Parse a httpx status error into this exception.""" - raw_response: Dict[str, Any] + raw_response: dict[str, Any] # the attempt to extract a response structure cannot afford failure. try: raw_response = httpx_error.response.json() @@ -128,16 +128,16 @@ class DevOpsAPITimeoutException(DevOpsAPIException): text: str timeout_type: str - endpoint: Optional[str] - raw_payload: Optional[str] + endpoint: str | None + raw_payload: str | None def __init__( self, text: str, *, timeout_type: str, - endpoint: Optional[str], - raw_payload: Optional[str], + endpoint: str | None, + raw_payload: str | None, ) -> None: super().__init__(text) self.text = text @@ -160,11 +160,11 @@ class DevOpsAPIErrorDescriptor: attributes: a dict with any further key-value pairs returned by the API. """ - id: Optional[int] - message: Optional[str] - attributes: Dict[str, Any] + id: int | None + message: str | None + attributes: dict[str, Any] - def __init__(self, error_dict: Dict[str, Any]) -> None: + def __init__(self, error_dict: dict[str, Any]) -> None: self.id = error_dict.get("ID") self.message = error_dict.get("message") self.attributes = { @@ -184,12 +184,12 @@ class DevOpsAPIFaultyResponseException(DevOpsAPIException): """ text: str - raw_response: Optional[Dict[str, Any]] + raw_response: dict[str, Any] | None def __init__( self, text: str, - raw_response: Optional[Dict[str, Any]], + raw_response: dict[str, Any] | None, ) -> None: super().__init__(text) self.text = text @@ -208,16 +208,16 @@ class DevOpsAPIResponseException(DevOpsAPIException): returned by the API in the response. """ - text: Optional[str] - command: Optional[Dict[str, Any]] - error_descriptors: List[DevOpsAPIErrorDescriptor] + text: str | None + command: dict[str, Any] | None + error_descriptors: list[DevOpsAPIErrorDescriptor] def __init__( self, - text: Optional[str] = None, + text: str | None = None, *, - command: Optional[Dict[str, Any]] = None, - error_descriptors: List[DevOpsAPIErrorDescriptor] = [], + command: dict[str, Any] | None = None, + error_descriptors: list[DevOpsAPIErrorDescriptor] = [], ) -> None: super().__init__(text or self.__class__.__name__) self.text = text @@ -226,8 +226,8 @@ def __init__( @staticmethod def from_response( - command: Optional[Dict[str, Any]], - raw_response: Dict[str, Any], + command: dict[str, Any] | None, + raw_response: dict[str, Any], ) -> DevOpsAPIResponseException: """Parse a raw response from the API into this exception.""" @@ -263,13 +263,13 @@ class DataAPIErrorDescriptor: attributes: a dict with any further key-value pairs returned by the API. """ - title: Optional[str] - error_code: Optional[str] - message: Optional[str] - family: Optional[str] - scope: Optional[str] - id: Optional[str] - attributes: Dict[str, Any] + title: str | None + error_code: str | None + message: str | None + family: str | None + scope: str | None + id: str | None + attributes: dict[str, Any] _known_dict_fields = { "title", @@ -280,7 +280,7 @@ class DataAPIErrorDescriptor: "id", } - def __init__(self, error_dict: Dict[str, str]) -> None: + def __init__(self, error_dict: dict[str, str]) -> None: self.title = error_dict.get("title") self.error_code = error_dict.get("errorCode") self.message = error_dict.get("message") @@ -322,9 +322,9 @@ class DataAPIDetailedErrorDescriptor: raw_response: the full API response in the form of a dict. """ - error_descriptors: List[DataAPIErrorDescriptor] - command: Optional[Dict[str, Any]] - raw_response: Dict[str, Any] + error_descriptors: list[DataAPIErrorDescriptor] + command: dict[str, Any] | None + raw_response: dict[str, Any] class DataAPIException(ValueError): @@ -356,15 +356,15 @@ class DataAPIHttpException(DataAPIException, httpx.HTTPStatusError): found in the response. """ - text: Optional[str] - error_descriptors: List[DataAPIErrorDescriptor] + text: str | None + error_descriptors: list[DataAPIErrorDescriptor] def __init__( self, - text: Optional[str], + text: str | None, *, httpx_error: httpx.HTTPStatusError, - error_descriptors: List[DataAPIErrorDescriptor], + error_descriptors: list[DataAPIErrorDescriptor], ) -> None: DataAPIException.__init__(self, text) httpx.HTTPStatusError.__init__( @@ -388,7 +388,7 @@ def from_httpx_error( ) -> DataAPIHttpException: """Parse a httpx status error into this exception.""" - raw_response: Dict[str, Any] + raw_response: dict[str, Any] # the attempt to extract a response structure cannot afford failure. try: raw_response = httpx_error.response.json() @@ -431,16 +431,16 @@ class DataAPITimeoutException(DataAPIException): text: str timeout_type: str - endpoint: Optional[str] - raw_payload: Optional[str] + endpoint: str | None + raw_payload: str | None def __init__( self, text: str, *, timeout_type: str, - endpoint: Optional[str], - raw_payload: Optional[str], + endpoint: str | None, + raw_payload: str | None, ) -> None: super().__init__(text) self.text = text @@ -579,12 +579,12 @@ class DataAPIFaultyResponseException(DataAPIException): """ text: str - raw_response: Optional[Dict[str, Any]] + raw_response: dict[str, Any] | None def __init__( self, text: str, - raw_response: Optional[Dict[str, Any]], + raw_response: dict[str, Any] | None, ) -> None: super().__init__(text) self.text = text @@ -614,16 +614,16 @@ class DataAPIResponseException(DataAPIException): has a single element. """ - text: Optional[str] - error_descriptors: List[DataAPIErrorDescriptor] - detailed_error_descriptors: List[DataAPIDetailedErrorDescriptor] + text: str | None + error_descriptors: list[DataAPIErrorDescriptor] + detailed_error_descriptors: list[DataAPIDetailedErrorDescriptor] def __init__( self, - text: Optional[str], + text: str | None, *, - error_descriptors: List[DataAPIErrorDescriptor], - detailed_error_descriptors: List[DataAPIDetailedErrorDescriptor], + error_descriptors: list[DataAPIErrorDescriptor], + detailed_error_descriptors: list[DataAPIDetailedErrorDescriptor], ) -> None: super().__init__(text) self.text = text @@ -633,8 +633,8 @@ def __init__( @classmethod def from_response( cls, - command: Optional[Dict[str, Any]], - raw_response: Dict[str, Any], + command: dict[str, Any] | None, + raw_response: dict[str, Any], **kwargs: Any, ) -> DataAPIResponseException: """Parse a raw response from the API into this exception.""" @@ -648,13 +648,13 @@ def from_response( @classmethod def from_responses( cls, - commands: List[Optional[Dict[str, Any]]], - raw_responses: List[Dict[str, Any]], + commands: list[dict[str, Any] | None], + raw_responses: list[dict[str, Any]], **kwargs: Any, ) -> DataAPIResponseException: """Parse a list of raw responses from the API into this exception.""" - detailed_error_descriptors: List[DataAPIDetailedErrorDescriptor] = [] + detailed_error_descriptors: list[DataAPIDetailedErrorDescriptor] = [] for command, raw_response in zip(commands, raw_responses): if raw_response.get("errors", []): error_descriptors = [ @@ -838,13 +838,13 @@ class BulkWriteException(DataAPIResponseException): """ partial_result: BulkWriteResult - exceptions: List[DataAPIResponseException] + exceptions: list[DataAPIResponseException] def __init__( self, - text: Optional[str], + text: str | None, partial_result: BulkWriteResult, - exceptions: List[DataAPIResponseException], + exceptions: list[DataAPIResponseException], *pargs: Any, **kwargs: Any, ) -> None: @@ -916,7 +916,7 @@ def to_devopsapi_timeout_exception( ) -def base_timeout_info(max_time_ms: Optional[int]) -> Union[TimeoutInfo, None]: +def base_timeout_info(max_time_ms: int | None) -> TimeoutInfo | None: if max_time_ms is not None: return {"base": max_time_ms / 1000.0} else: @@ -937,12 +937,12 @@ class MultiCallTimeoutManager: deadline_ms: optional deadline in milliseconds (computed by the class). """ - overall_max_time_ms: Optional[int] + overall_max_time_ms: int | None started_ms: int = -1 - deadline_ms: Optional[int] + deadline_ms: int | None def __init__( - self, overall_max_time_ms: Optional[int], dev_ops_api: bool = False + self, overall_max_time_ms: int | None, dev_ops_api: bool = False ) -> None: self.started_ms = int(time.time() * 1000) self.overall_max_time_ms = overall_max_time_ms @@ -952,7 +952,7 @@ def __init__( else: self.deadline_ms = None - def remaining_timeout_ms(self) -> Union[int, None]: + def remaining_timeout_ms(self) -> int | None: """ Ensure the deadline, if any, is not yet in the past. If it is, raise an appropriate timeout error. @@ -981,7 +981,7 @@ def remaining_timeout_ms(self) -> Union[int, None]: else: return None - def remaining_timeout_info(self) -> Union[TimeoutInfo, None]: + def remaining_timeout_info(self) -> TimeoutInfo | None: """ Ensure the deadline, if any, is not yet in the past. If it is, raise an appropriate timeout error. diff --git a/astrapy/info.py b/astrapy/info.py index 7b9222b6..b72eeee8 100644 --- a/astrapy/info.py +++ b/astrapy/info.py @@ -16,7 +16,7 @@ import warnings from dataclasses import dataclass -from typing import Any, Dict, List, Optional +from typing import Any @dataclass @@ -52,11 +52,11 @@ class DatabaseInfo: id: str region: str - keyspace: Optional[str] - namespace: Optional[str] + keyspace: str | None + namespace: str | None name: str environment: str - raw_info: Optional[Dict[str, Any]] + raw_info: dict[str, Any] | None @dataclass @@ -92,8 +92,8 @@ class AdminDatabaseInfo: """ info: DatabaseInfo - available_actions: Optional[List[str]] - cost: Dict[str, Any] + available_actions: list[str] | None + cost: dict[str, Any] cqlsh_url: str creation_time: str data_endpoint_url: str @@ -101,14 +101,14 @@ class AdminDatabaseInfo: graphql_url: str id: str last_usage_time: str - metrics: Dict[str, Any] + metrics: dict[str, Any] observed_status: str org_id: str owner_id: str status: str - storage: Dict[str, Any] + storage: dict[str, Any] termination_time: str - raw_info: Optional[Dict[str, Any]] + raw_info: dict[str, Any] | None @dataclass @@ -145,15 +145,15 @@ class CollectionDefaultIDOptions: default_id_type: str - def as_dict(self) -> Dict[str, Any]: + def as_dict(self) -> dict[str, Any]: """Recast this object into a dictionary.""" return {"type": self.default_id_type} @staticmethod def from_dict( - raw_dict: Optional[Dict[str, Any]] - ) -> Optional[CollectionDefaultIDOptions]: + raw_dict: dict[str, Any] | None + ) -> CollectionDefaultIDOptions | None: """ Create an instance of CollectionDefaultIDOptions from a dictionary such as one from the Data API. @@ -180,12 +180,12 @@ class CollectionVectorServiceOptions: in the vector service options. """ - provider: Optional[str] - model_name: Optional[str] - authentication: Optional[Dict[str, Any]] = None - parameters: Optional[Dict[str, Any]] = None + provider: str | None + model_name: str | None + authentication: dict[str, Any] | None = None + parameters: dict[str, Any] | None = None - def as_dict(self) -> Dict[str, Any]: + def as_dict(self) -> dict[str, Any]: """Recast this object into a dictionary.""" return { @@ -201,8 +201,8 @@ def as_dict(self) -> Dict[str, Any]: @staticmethod def from_dict( - raw_dict: Optional[Dict[str, Any]] - ) -> Optional[CollectionVectorServiceOptions]: + raw_dict: dict[str, Any] | None + ) -> CollectionVectorServiceOptions | None: """ Create an instance of CollectionVectorServiceOptions from a dictionary such as one from the Data API. @@ -233,11 +233,11 @@ class CollectionVectorOptions: service is configured for the collection. """ - dimension: Optional[int] - metric: Optional[str] - service: Optional[CollectionVectorServiceOptions] + dimension: int | None + metric: str | None + service: CollectionVectorServiceOptions | None - def as_dict(self) -> Dict[str, Any]: + def as_dict(self) -> dict[str, Any]: """Recast this object into a dictionary.""" return { @@ -252,8 +252,8 @@ def as_dict(self) -> Dict[str, Any]: @staticmethod def from_dict( - raw_dict: Optional[Dict[str, Any]] - ) -> Optional[CollectionVectorOptions]: + raw_dict: dict[str, Any] | None + ) -> CollectionVectorOptions | None: """ Create an instance of CollectionVectorOptions from a dictionary such as one from the Data API. @@ -284,10 +284,10 @@ class CollectionOptions: raw_options: the raw response from the Data API for the collection configuration. """ - vector: Optional[CollectionVectorOptions] - indexing: Optional[Dict[str, Any]] - default_id: Optional[CollectionDefaultIDOptions] - raw_options: Optional[Dict[str, Any]] + vector: CollectionVectorOptions | None + indexing: dict[str, Any] | None + default_id: CollectionDefaultIDOptions | None + raw_options: dict[str, Any] | None def __repr__(self) -> str: not_null_pieces = [ @@ -310,7 +310,7 @@ def __repr__(self) -> str: ] return f"{self.__class__.__name__}({', '.join(not_null_pieces)})" - def as_dict(self) -> Dict[str, Any]: + def as_dict(self) -> dict[str, Any]: """Recast this object into a dictionary.""" return { @@ -325,17 +325,17 @@ def as_dict(self) -> Dict[str, Any]: if v is not None } - def flatten(self) -> Dict[str, Any]: + def flatten(self) -> dict[str, Any]: """ Recast this object as a flat key-value pair suitable for use as kwargs in a create_collection method call (including recasts). """ - _dimension: Optional[int] - _metric: Optional[str] - _indexing: Optional[Dict[str, Any]] - _service: Optional[Dict[str, Any]] - _default_id_type: Optional[str] + _dimension: int | None + _metric: str | None + _indexing: dict[str, Any] | None + _service: dict[str, Any] | None + _default_id_type: str | None if self.vector is not None: _dimension = self.vector.dimension _metric = self.vector.metric @@ -366,7 +366,7 @@ def flatten(self) -> Dict[str, Any]: } @staticmethod - def from_dict(raw_dict: Dict[str, Any]) -> CollectionOptions: + def from_dict(raw_dict: dict[str, Any]) -> CollectionOptions: """ Create an instance of CollectionOptions from a dictionary such as one from the Data API. @@ -394,7 +394,7 @@ class CollectionDescriptor: name: str options: CollectionOptions - raw_descriptor: Optional[Dict[str, Any]] + raw_descriptor: dict[str, Any] | None def __repr__(self) -> str: not_null_pieces = [ @@ -408,7 +408,7 @@ def __repr__(self) -> str: ] return f"{self.__class__.__name__}({', '.join(not_null_pieces)})" - def as_dict(self) -> Dict[str, Any]: + def as_dict(self) -> dict[str, Any]: """ Recast this object into a dictionary. Empty `options` will not be returned at all. @@ -423,7 +423,7 @@ def as_dict(self) -> Dict[str, Any]: if v } - def flatten(self) -> Dict[str, Any]: + def flatten(self) -> dict[str, Any]: """ Recast this object as a flat key-value pair suitable for use as kwargs in a create_collection method call (including recasts). @@ -435,7 +435,7 @@ def flatten(self) -> Dict[str, Any]: } @staticmethod - def from_dict(raw_dict: Dict[str, Any]) -> CollectionDescriptor: + def from_dict(raw_dict: dict[str, Any]) -> CollectionDescriptor: """ Create an instance of CollectionDescriptor from a dictionary such as one from the Data API. @@ -464,18 +464,18 @@ class EmbeddingProviderParameter: """ default_value: Any - display_name: Optional[str] - help: Optional[str] - hint: Optional[str] + display_name: str | None + help: str | None + hint: str | None name: str required: bool parameter_type: str - validation: Dict[str, Any] + validation: dict[str, Any] def __repr__(self) -> str: return f"EmbeddingProviderParameter(name='{self.name}')" - def as_dict(self) -> Dict[str, Any]: + def as_dict(self) -> dict[str, Any]: """Recast this object into a dictionary.""" return { @@ -494,7 +494,7 @@ def as_dict(self) -> Dict[str, Any]: } @staticmethod - def from_dict(raw_dict: Dict[str, Any]) -> EmbeddingProviderParameter: + def from_dict(raw_dict: dict[str, Any]) -> EmbeddingProviderParameter: """ Create an instance of EmbeddingProviderParameter from a dictionary such as one from the Data API. @@ -543,13 +543,13 @@ class EmbeddingProviderModel: """ name: str - parameters: List[EmbeddingProviderParameter] - vector_dimension: Optional[int] + parameters: list[EmbeddingProviderParameter] + vector_dimension: int | None def __repr__(self) -> str: return f"EmbeddingProviderModel(name='{self.name}')" - def as_dict(self) -> Dict[str, Any]: + def as_dict(self) -> dict[str, Any]: """Recast this object into a dictionary.""" return { @@ -559,7 +559,7 @@ def as_dict(self) -> Dict[str, Any]: } @staticmethod - def from_dict(raw_dict: Dict[str, Any]) -> EmbeddingProviderModel: + def from_dict(raw_dict: dict[str, Any]) -> EmbeddingProviderModel: """ Create an instance of EmbeddingProviderModel from a dictionary such as one from the Data API. @@ -606,7 +606,7 @@ class EmbeddingProviderToken: def __repr__(self) -> str: return f"EmbeddingProviderToken('{self.accepted}')" - def as_dict(self) -> Dict[str, Any]: + def as_dict(self) -> dict[str, Any]: """Recast this object into a dictionary.""" return { @@ -615,7 +615,7 @@ def as_dict(self) -> Dict[str, Any]: } @staticmethod - def from_dict(raw_dict: Dict[str, Any]) -> EmbeddingProviderToken: + def from_dict(raw_dict: dict[str, Any]) -> EmbeddingProviderToken: """ Create an instance of EmbeddingProviderToken from a dictionary such as one from the Data API. @@ -650,7 +650,7 @@ class EmbeddingProviderAuthentication: """ enabled: bool - tokens: List[EmbeddingProviderToken] + tokens: list[EmbeddingProviderToken] def __repr__(self) -> str: return ( @@ -658,7 +658,7 @@ def __repr__(self) -> str: f"tokens={','.join(str(token) for token in self.tokens)})" ) - def as_dict(self) -> Dict[str, Any]: + def as_dict(self) -> dict[str, Any]: """Recast this object into a dictionary.""" return { @@ -667,7 +667,7 @@ def as_dict(self) -> Dict[str, Any]: } @staticmethod - def from_dict(raw_dict: Dict[str, Any]) -> EmbeddingProviderAuthentication: + def from_dict(raw_dict: dict[str, Any]) -> EmbeddingProviderAuthentication: """ Create an instance of EmbeddingProviderAuthentication from a dictionary such as one from the Data API. @@ -714,13 +714,13 @@ class EmbeddingProvider: def __repr__(self) -> str: return f"EmbeddingProvider(display_name='{self.display_name}', models={self.models})" - display_name: Optional[str] - models: List[EmbeddingProviderModel] - parameters: List[EmbeddingProviderParameter] - supported_authentication: Dict[str, EmbeddingProviderAuthentication] - url: Optional[str] + display_name: str | None + models: list[EmbeddingProviderModel] + parameters: list[EmbeddingProviderParameter] + supported_authentication: dict[str, EmbeddingProviderAuthentication] + url: str | None - def as_dict(self) -> Dict[str, Any]: + def as_dict(self) -> dict[str, Any]: """Recast this object into a dictionary.""" return { @@ -735,7 +735,7 @@ def as_dict(self) -> Dict[str, Any]: } @staticmethod - def from_dict(raw_dict: Dict[str, Any]) -> EmbeddingProvider: + def from_dict(raw_dict: dict[str, Any]) -> EmbeddingProvider: """ Create an instance of EmbeddingProvider from a dictionary such as one from the Data API. @@ -788,10 +788,10 @@ def __repr__(self) -> str: f"{', '.join(sorted(self.embedding_providers.keys()))})" ) - embedding_providers: Dict[str, EmbeddingProvider] - raw_info: Optional[Dict[str, Any]] + embedding_providers: dict[str, EmbeddingProvider] + raw_info: dict[str, Any] | None - def as_dict(self) -> Dict[str, Any]: + def as_dict(self) -> dict[str, Any]: """Recast this object into a dictionary.""" return { @@ -802,7 +802,7 @@ def as_dict(self) -> Dict[str, Any]: } @staticmethod - def from_dict(raw_dict: Dict[str, Any]) -> FindEmbeddingProvidersResult: + def from_dict(raw_dict: dict[str, Any]) -> FindEmbeddingProvidersResult: """ Create an instance of FindEmbeddingProvidersResult from a dictionary such as one from the Data API. diff --git a/astrapy/meta.py b/astrapy/meta.py index 470d4409..6d704997 100644 --- a/astrapy/meta.py +++ b/astrapy/meta.py @@ -15,7 +15,7 @@ from __future__ import annotations import warnings -from typing import Any, Optional +from typing import Any from deprecation import DeprecatedWarning @@ -70,9 +70,9 @@ def check_deprecated_vector_ize( def check_namespace_keyspace( - keyspace: Optional[str], - namespace: Optional[str], -) -> Optional[str]: + keyspace: str | None, + namespace: str | None, +) -> str | None: # normalize the two aliased parameter names, raising deprecation # when needed and an error if both parameter supplied. # The returned value is the final one for the parameter. @@ -104,9 +104,9 @@ def check_namespace_keyspace( def check_update_db_namespace_keyspace( - update_db_keyspace: Optional[bool], - update_db_namespace: Optional[bool], -) -> Optional[bool]: + update_db_keyspace: bool | None, + update_db_namespace: bool | None, +) -> bool | None: # normalize the two aliased parameter names, raising deprecation # when needed and an error if both parameter supplied. # The returned value is the final one for the parameter. diff --git a/astrapy/operations.py b/astrapy/operations.py index c7997622..cb9de825 100644 --- a/astrapy/operations.py +++ b/astrapy/operations.py @@ -17,7 +17,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from functools import reduce -from typing import Any, Dict, Iterable, List, Optional +from typing import Any, Iterable from astrapy.collection import AsyncCollection, Collection from astrapy.constants import DocumentType, SortType, VectorType @@ -31,7 +31,7 @@ ) -def reduce_bulk_write_results(results: List[BulkWriteResult]) -> BulkWriteResult: +def reduce_bulk_write_results(results: list[BulkWriteResult]) -> BulkWriteResult: """ Reduce a list of bulk write results into a single one. @@ -79,7 +79,7 @@ def execute( self, collection: Collection, index_in_bulk_write: int, - bulk_write_timeout_ms: Optional[int], + bulk_write_timeout_ms: int | None, ) -> BulkWriteResult: ... @@ -100,15 +100,15 @@ class InsertOne(BaseOperation): """ document: DocumentType - vector: Optional[VectorType] - vectorize: Optional[str] + vector: VectorType | None + vectorize: str | None def __init__( self, document: DocumentType, *, - vector: Optional[VectorType] = None, - vectorize: Optional[str] = None, + vector: VectorType | None = None, + vectorize: str | None = None, ) -> None: self.document = document check_deprecated_vector_ize( @@ -124,7 +124,7 @@ def execute( self, collection: Collection, index_in_bulk_write: int, - bulk_write_timeout_ms: Optional[int], + bulk_write_timeout_ms: int | None, ) -> BulkWriteResult: """ Execute this operation against a collection as part of a bulk write. @@ -165,21 +165,21 @@ class InsertMany(BaseOperation): """ documents: Iterable[DocumentType] - vectors: Optional[Iterable[Optional[VectorType]]] - vectorize: Optional[Iterable[Optional[str]]] + vectors: Iterable[VectorType | None] | None + vectorize: Iterable[str | None] | None ordered: bool - chunk_size: Optional[int] - concurrency: Optional[int] + chunk_size: int | None + concurrency: int | None def __init__( self, documents: Iterable[DocumentType], *, - vectors: Optional[Iterable[Optional[VectorType]]] = None, - vectorize: Optional[Iterable[Optional[str]]] = None, + vectors: Iterable[VectorType | None] | None = None, + vectorize: Iterable[str | None] | None = None, ordered: bool = True, - chunk_size: Optional[int] = None, - concurrency: Optional[int] = None, + chunk_size: int | None = None, + concurrency: int | None = None, ) -> None: self.documents = documents self.ordered = ordered @@ -198,7 +198,7 @@ def execute( self, collection: Collection, index_in_bulk_write: int, - bulk_write_timeout_ms: Optional[int], + bulk_write_timeout_ms: int | None, ) -> BulkWriteResult: """ Execute this operation against a collection as part of a bulk write. @@ -241,21 +241,21 @@ class UpdateOne(BaseOperation): upsert: controls what to do when no documents are found. """ - filter: Dict[str, Any] - update: Dict[str, Any] - vector: Optional[VectorType] - vectorize: Optional[str] - sort: Optional[SortType] + filter: dict[str, Any] + update: dict[str, Any] + vector: VectorType | None + vectorize: str | None + sort: SortType | None upsert: bool def __init__( self, - filter: Dict[str, Any], - update: Dict[str, Any], + filter: dict[str, Any], + update: dict[str, Any], *, - vector: Optional[VectorType] = None, - vectorize: Optional[str] = None, - sort: Optional[SortType] = None, + vector: VectorType | None = None, + vectorize: str | None = None, + sort: SortType | None = None, upsert: bool = False, ) -> None: self.filter = filter @@ -275,7 +275,7 @@ def execute( self, collection: Collection, index_in_bulk_write: int, - bulk_write_timeout_ms: Optional[int], + bulk_write_timeout_ms: int | None, ) -> BulkWriteResult: """ Execute this operation against a collection as part of a bulk write. @@ -309,14 +309,14 @@ class UpdateMany(BaseOperation): upsert: controls what to do when no documents are found. """ - filter: Dict[str, Any] - update: Dict[str, Any] + filter: dict[str, Any] + update: dict[str, Any] upsert: bool def __init__( self, - filter: Dict[str, Any], - update: Dict[str, Any], + filter: dict[str, Any], + update: dict[str, Any], *, upsert: bool = False, ) -> None: @@ -328,7 +328,7 @@ def execute( self, collection: Collection, index_in_bulk_write: int, - bulk_write_timeout_ms: Optional[int], + bulk_write_timeout_ms: int | None, ) -> BulkWriteResult: """ Execute this operation against a collection as part of a bulk write. @@ -368,21 +368,21 @@ class ReplaceOne(BaseOperation): upsert: controls what to do when no documents are found. """ - filter: Dict[str, Any] + filter: dict[str, Any] replacement: DocumentType - vector: Optional[VectorType] - vectorize: Optional[str] - sort: Optional[SortType] + vector: VectorType | None + vectorize: str | None + sort: SortType | None upsert: bool def __init__( self, - filter: Dict[str, Any], + filter: dict[str, Any], replacement: DocumentType, *, - vector: Optional[VectorType] = None, - vectorize: Optional[str] = None, - sort: Optional[SortType] = None, + vector: VectorType | None = None, + vectorize: str | None = None, + sort: SortType | None = None, upsert: bool = False, ) -> None: self.filter = filter @@ -402,7 +402,7 @@ def execute( self, collection: Collection, index_in_bulk_write: int, - bulk_write_timeout_ms: Optional[int], + bulk_write_timeout_ms: int | None, ) -> BulkWriteResult: """ Execute this operation against a collection as part of a bulk write. @@ -443,18 +443,18 @@ class DeleteOne(BaseOperation): sort: controls ordering of results, hence which document is affected. """ - filter: Dict[str, Any] - vector: Optional[VectorType] - vectorize: Optional[str] - sort: Optional[SortType] + filter: dict[str, Any] + vector: VectorType | None + vectorize: str | None + sort: SortType | None def __init__( self, - filter: Dict[str, Any], + filter: dict[str, Any], *, - vector: Optional[VectorType] = None, - vectorize: Optional[str] = None, - sort: Optional[SortType] = None, + vector: VectorType | None = None, + vectorize: str | None = None, + sort: SortType | None = None, ) -> None: self.filter = filter check_deprecated_vector_ize( @@ -471,7 +471,7 @@ def execute( self, collection: Collection, index_in_bulk_write: int, - bulk_write_timeout_ms: Optional[int], + bulk_write_timeout_ms: int | None, ) -> BulkWriteResult: """ Execute this operation against a collection as part of a bulk write. @@ -501,11 +501,11 @@ class DeleteMany(BaseOperation): filter: a filter condition to select target documents. """ - filter: Dict[str, Any] + filter: dict[str, Any] def __init__( self, - filter: Dict[str, Any], + filter: dict[str, Any], ) -> None: self.filter = filter @@ -513,7 +513,7 @@ def execute( self, collection: Collection, index_in_bulk_write: int, - bulk_write_timeout_ms: Optional[int], + bulk_write_timeout_ms: int | None, ) -> BulkWriteResult: """ Execute this operation against a collection as part of a bulk write. @@ -540,7 +540,7 @@ async def execute( self, collection: AsyncCollection, index_in_bulk_write: int, - bulk_write_timeout_ms: Optional[int], + bulk_write_timeout_ms: int | None, ) -> BulkWriteResult: ... @@ -561,15 +561,15 @@ class AsyncInsertOne(AsyncBaseOperation): """ document: DocumentType - vector: Optional[VectorType] - vectorize: Optional[str] + vector: VectorType | None + vectorize: str | None def __init__( self, document: DocumentType, *, - vector: Optional[VectorType] = None, - vectorize: Optional[str] = None, + vector: VectorType | None = None, + vectorize: str | None = None, ) -> None: self.document = document check_deprecated_vector_ize( @@ -585,7 +585,7 @@ async def execute( self, collection: AsyncCollection, index_in_bulk_write: int, - bulk_write_timeout_ms: Optional[int], + bulk_write_timeout_ms: int | None, ) -> BulkWriteResult: """ Execute this operation against a collection as part of a bulk write. @@ -626,21 +626,21 @@ class AsyncInsertMany(AsyncBaseOperation): """ documents: Iterable[DocumentType] - vectors: Optional[Iterable[Optional[VectorType]]] - vectorize: Optional[Iterable[Optional[str]]] + vectors: Iterable[VectorType | None] | None + vectorize: Iterable[str | None] | None ordered: bool - chunk_size: Optional[int] - concurrency: Optional[int] + chunk_size: int | None + concurrency: int | None def __init__( self, documents: Iterable[DocumentType], *, - vectors: Optional[Iterable[Optional[VectorType]]] = None, - vectorize: Optional[Iterable[Optional[str]]] = None, + vectors: Iterable[VectorType | None] | None = None, + vectorize: Iterable[str | None] | None = None, ordered: bool = True, - chunk_size: Optional[int] = None, - concurrency: Optional[int] = None, + chunk_size: int | None = None, + concurrency: int | None = None, ) -> None: self.documents = documents check_deprecated_vector_ize( @@ -659,7 +659,7 @@ async def execute( self, collection: AsyncCollection, index_in_bulk_write: int, - bulk_write_timeout_ms: Optional[int], + bulk_write_timeout_ms: int | None, ) -> BulkWriteResult: """ Execute this operation against a collection as part of a bulk write. @@ -702,21 +702,21 @@ class AsyncUpdateOne(AsyncBaseOperation): upsert: controls what to do when no documents are found. """ - filter: Dict[str, Any] - update: Dict[str, Any] - vector: Optional[VectorType] - vectorize: Optional[str] - sort: Optional[SortType] + filter: dict[str, Any] + update: dict[str, Any] + vector: VectorType | None + vectorize: str | None + sort: SortType | None upsert: bool def __init__( self, - filter: Dict[str, Any], - update: Dict[str, Any], + filter: dict[str, Any], + update: dict[str, Any], *, - vector: Optional[VectorType] = None, - vectorize: Optional[str] = None, - sort: Optional[SortType] = None, + vector: VectorType | None = None, + vectorize: str | None = None, + sort: SortType | None = None, upsert: bool = False, ) -> None: self.filter = filter @@ -736,7 +736,7 @@ async def execute( self, collection: AsyncCollection, index_in_bulk_write: int, - bulk_write_timeout_ms: Optional[int], + bulk_write_timeout_ms: int | None, ) -> BulkWriteResult: """ Execute this operation against a collection as part of a bulk write. @@ -770,14 +770,14 @@ class AsyncUpdateMany(AsyncBaseOperation): upsert: controls what to do when no documents are found. """ - filter: Dict[str, Any] - update: Dict[str, Any] + filter: dict[str, Any] + update: dict[str, Any] upsert: bool def __init__( self, - filter: Dict[str, Any], - update: Dict[str, Any], + filter: dict[str, Any], + update: dict[str, Any], *, upsert: bool = False, ) -> None: @@ -789,7 +789,7 @@ async def execute( self, collection: AsyncCollection, index_in_bulk_write: int, - bulk_write_timeout_ms: Optional[int], + bulk_write_timeout_ms: int | None, ) -> BulkWriteResult: """ Execute this operation against a collection as part of a bulk write. @@ -829,21 +829,21 @@ class AsyncReplaceOne(AsyncBaseOperation): upsert: controls what to do when no documents are found. """ - filter: Dict[str, Any] + filter: dict[str, Any] replacement: DocumentType - vector: Optional[VectorType] - vectorize: Optional[str] - sort: Optional[SortType] + vector: VectorType | None + vectorize: str | None + sort: SortType | None upsert: bool def __init__( self, - filter: Dict[str, Any], + filter: dict[str, Any], replacement: DocumentType, *, - vector: Optional[VectorType] = None, - vectorize: Optional[str] = None, - sort: Optional[SortType] = None, + vector: VectorType | None = None, + vectorize: str | None = None, + sort: SortType | None = None, upsert: bool = False, ) -> None: self.filter = filter @@ -863,7 +863,7 @@ async def execute( self, collection: AsyncCollection, index_in_bulk_write: int, - bulk_write_timeout_ms: Optional[int], + bulk_write_timeout_ms: int | None, ) -> BulkWriteResult: """ Execute this operation against a collection as part of a bulk write. @@ -904,18 +904,18 @@ class AsyncDeleteOne(AsyncBaseOperation): sort: controls ordering of results, hence which document is affected. """ - filter: Dict[str, Any] - vector: Optional[VectorType] - vectorize: Optional[str] - sort: Optional[SortType] + filter: dict[str, Any] + vector: VectorType | None + vectorize: str | None + sort: SortType | None def __init__( self, - filter: Dict[str, Any], + filter: dict[str, Any], *, - vector: Optional[VectorType] = None, - vectorize: Optional[str] = None, - sort: Optional[SortType] = None, + vector: VectorType | None = None, + vectorize: str | None = None, + sort: SortType | None = None, ) -> None: self.filter = filter check_deprecated_vector_ize( @@ -932,7 +932,7 @@ async def execute( self, collection: AsyncCollection, index_in_bulk_write: int, - bulk_write_timeout_ms: Optional[int], + bulk_write_timeout_ms: int | None, ) -> BulkWriteResult: """ Execute this operation against a collection as part of a bulk write. @@ -962,11 +962,11 @@ class AsyncDeleteMany(AsyncBaseOperation): filter: a filter condition to select target documents. """ - filter: Dict[str, Any] + filter: dict[str, Any] def __init__( self, - filter: Dict[str, Any], + filter: dict[str, Any], ) -> None: self.filter = filter @@ -974,7 +974,7 @@ async def execute( self, collection: AsyncCollection, index_in_bulk_write: int, - bulk_write_timeout_ms: Optional[int], + bulk_write_timeout_ms: int | None, ) -> BulkWriteResult: """ Execute this operation against a collection as part of a bulk write. diff --git a/astrapy/request_tools.py b/astrapy/request_tools.py index 8d67812a..3a1a1be6 100644 --- a/astrapy/request_tools.py +++ b/astrapy/request_tools.py @@ -15,7 +15,7 @@ from __future__ import annotations import logging -from typing import Any, Dict, Optional, TypedDict, Union +from typing import Any, TypedDict, Union import httpx @@ -27,9 +27,9 @@ def log_httpx_request( http_method: str, full_url: str, - request_params: Optional[Dict[str, Any]], - redacted_request_headers: Dict[str, str], - payload: Optional[Dict[str, Any]], + request_params: dict[str, Any] | None, + redacted_request_headers: dict[str, str], + payload: dict[str, Any] | None, ) -> None: """ Log the details of an HTTP request for debugging purposes. @@ -79,7 +79,7 @@ class TimeoutInfo(TypedDict, total=False): TimeoutInfoWideType = Union[TimeoutInfo, float, None] -def to_httpx_timeout(timeout_info: TimeoutInfoWideType) -> Union[httpx.Timeout, None]: +def to_httpx_timeout(timeout_info: TimeoutInfoWideType) -> httpx.Timeout | None: if timeout_info is None: return None if isinstance(timeout_info, float) or isinstance(timeout_info, int): diff --git a/astrapy/results.py b/astrapy/results.py index 7f150af0..e2c5715c 100644 --- a/astrapy/results.py +++ b/astrapy/results.py @@ -16,7 +16,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Dict, List, Optional +from typing import Any @dataclass @@ -30,9 +30,9 @@ class OperationResult(ABC): list of raw responses can contain exactly one or a number of items. """ - raw_results: List[Dict[str, Any]] + raw_results: list[dict[str, Any]] - def _piecewise_repr(self, pieces: List[Optional[str]]) -> str: + def _piecewise_repr(self, pieces: list[str | None]) -> str: return f"{self.__class__.__name__}({', '.join(pc for pc in pieces if pc)})" @abstractmethod @@ -51,7 +51,7 @@ class DeleteResult(OperationResult): list of raw responses can contain exactly one or a number of items. """ - deleted_count: Optional[int] + deleted_count: int | None def __repr__(self) -> str: return self._piecewise_repr( @@ -115,7 +115,7 @@ class InsertManyResult(OperationResult): inserted_ids: list of the IDs of the inserted documents """ - inserted_ids: List[Any] + inserted_ids: list[Any] def __repr__(self) -> str: return self._piecewise_repr( @@ -153,7 +153,7 @@ class UpdateResult(OperationResult): """ - update_info: Dict[str, Any] + update_info: dict[str, Any] def __repr__(self) -> str: return self._piecewise_repr( @@ -201,13 +201,13 @@ class BulkWriteResult: upserted_ids: a (sparse) map from indices to ID of the upserted document """ - bulk_api_results: Dict[int, List[Dict[str, Any]]] - deleted_count: Optional[int] + bulk_api_results: dict[int, list[dict[str, Any]]] + deleted_count: int | None inserted_count: int matched_count: int modified_count: int upserted_count: int - upserted_ids: Dict[int, Any] + upserted_ids: dict[int, Any] def __repr__(self) -> str: pieces = [ diff --git a/astrapy/transform_payload.py b/astrapy/transform_payload.py index 9c2f78a1..e3bda7f3 100644 --- a/astrapy/transform_payload.py +++ b/astrapy/transform_payload.py @@ -16,13 +16,13 @@ import datetime import time -from typing import Any, Dict, Iterable, List, Union, cast +from typing import Any, Dict, Iterable, cast from astrapy.constants import DocumentType from astrapy.ids import UUID, ObjectId -def convert_vector_to_floats(vector: Iterable[Any]) -> List[float]: +def convert_vector_to_floats(vector: Iterable[Any]) -> list[float]: """ Convert a vector of strings to a vector of floats. @@ -46,36 +46,36 @@ def is_list_of_floats(vector: Iterable[Any]) -> bool: def convert_to_ejson_date_object( - date_value: Union[datetime.date, datetime.datetime] -) -> Dict[str, int]: + date_value: datetime.date | datetime.datetime, +) -> dict[str, int]: return {"$date": int(time.mktime(date_value.timetuple()) * 1000)} -def convert_to_ejson_uuid_object(uuid_value: UUID) -> Dict[str, str]: +def convert_to_ejson_uuid_object(uuid_value: UUID) -> dict[str, str]: return {"$uuid": str(uuid_value)} -def convert_to_ejson_objectid_object(objectid_value: ObjectId) -> Dict[str, str]: +def convert_to_ejson_objectid_object(objectid_value: ObjectId) -> dict[str, str]: return {"$objectId": str(objectid_value)} def convert_ejson_date_object_to_datetime( - date_object: Dict[str, int] + date_object: dict[str, int], ) -> datetime.datetime: return datetime.datetime.fromtimestamp(date_object["$date"] / 1000.0) -def convert_ejson_uuid_object_to_uuid(uuid_object: Dict[str, str]) -> UUID: +def convert_ejson_uuid_object_to_uuid(uuid_object: dict[str, str]) -> UUID: return UUID(uuid_object["$uuid"]) def convert_ejson_objectid_object_to_objectid( - objectid_object: Dict[str, str] + objectid_object: dict[str, str], ) -> ObjectId: return ObjectId(objectid_object["$objectId"]) -def normalize_payload_value(path: List[str], value: Any) -> Any: +def normalize_payload_value(path: list[str], value: Any) -> Any: """ The path helps determining special treatments """ @@ -104,9 +104,7 @@ def normalize_payload_value(path: List[str], value: Any) -> Any: return value -def normalize_for_api( - payload: Union[Dict[str, Any], None] -) -> Union[Dict[str, Any], None]: +def normalize_for_api(payload: dict[str, Any] | None) -> dict[str, Any] | None: """ Normalize a payload for API calls. This includes e.g. ensuring values for "$vector" key @@ -125,7 +123,7 @@ def normalize_for_api( return payload -def restore_response_value(path: List[str], value: Any) -> Any: +def restore_response_value(path: list[str], value: Any) -> Any: """ The path helps determining special treatments """ diff --git a/astrapy/user_agents.py b/astrapy/user_agents.py index b17aee60..9c9572a5 100644 --- a/astrapy/user_agents.py +++ b/astrapy/user_agents.py @@ -16,17 +16,16 @@ from importlib import metadata from importlib.metadata import PackageNotFoundError -from typing import List, Optional, Tuple from astrapy import __version__ -def detect_astrapy_user_agent() -> Tuple[Optional[str], Optional[str]]: +def detect_astrapy_user_agent() -> tuple[str | None, str | None]: package_name = __name__.split(".")[0] return (package_name, __version__) -def detect_ragstack_user_agent() -> Tuple[Optional[str], Optional[str]]: +def detect_ragstack_user_agent() -> tuple[str | None, str | None]: try: ragstack_meta = metadata.metadata("ragstack-ai") if ragstack_meta: @@ -38,8 +37,8 @@ def detect_ragstack_user_agent() -> Tuple[Optional[str], Optional[str]]: def compose_user_agent_string( - caller_name: Optional[str], caller_version: Optional[str] -) -> Optional[str]: + caller_name: str | None, caller_version: str | None +) -> str | None: if caller_name: if caller_version: return f"{caller_name}/{caller_version}" @@ -49,9 +48,7 @@ def compose_user_agent_string( return None -def compose_full_user_agent( - callers: List[Tuple[Optional[str], Optional[str]]] -) -> Optional[str]: +def compose_full_user_agent(callers: list[tuple[str | None, str | None]]) -> str | None: user_agent_strings = [ ua_string for ua_string in ( diff --git a/poetry.lock b/poetry.lock index ce8a12cc..2881d574 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "anyio" @@ -1179,28 +1179,29 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] [[package]] name = "ruff" -version = "0.2.2" +version = "0.6.6" description = "An extremely fast Python linter and code formatter, written in Rust." optional = false python-versions = ">=3.7" files = [ - {file = "ruff-0.2.2-py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:0a9efb032855ffb3c21f6405751d5e147b0c6b631e3ca3f6b20f917572b97eb6"}, - {file = "ruff-0.2.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:d450b7fbff85913f866a5384d8912710936e2b96da74541c82c1b458472ddb39"}, - {file = "ruff-0.2.2-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ecd46e3106850a5c26aee114e562c329f9a1fbe9e4821b008c4404f64ff9ce73"}, - {file = "ruff-0.2.2-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5e22676a5b875bd72acd3d11d5fa9075d3a5f53b877fe7b4793e4673499318ba"}, - {file = "ruff-0.2.2-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1695700d1e25a99d28f7a1636d85bafcc5030bba9d0578c0781ba1790dbcf51c"}, - {file = "ruff-0.2.2-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:b0c232af3d0bd8f521806223723456ffebf8e323bd1e4e82b0befb20ba18388e"}, - {file = "ruff-0.2.2-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f63d96494eeec2fc70d909393bcd76c69f35334cdbd9e20d089fb3f0640216ca"}, - {file = "ruff-0.2.2-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6a61ea0ff048e06de273b2e45bd72629f470f5da8f71daf09fe481278b175001"}, - {file = "ruff-0.2.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5e1439c8f407e4f356470e54cdecdca1bd5439a0673792dbe34a2b0a551a2fe3"}, - {file = "ruff-0.2.2-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:940de32dc8853eba0f67f7198b3e79bc6ba95c2edbfdfac2144c8235114d6726"}, - {file = "ruff-0.2.2-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:0c126da55c38dd917621552ab430213bdb3273bb10ddb67bc4b761989210eb6e"}, - {file = "ruff-0.2.2-py3-none-musllinux_1_2_i686.whl", hash = "sha256:3b65494f7e4bed2e74110dac1f0d17dc8e1f42faaa784e7c58a98e335ec83d7e"}, - {file = "ruff-0.2.2-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:1ec49be4fe6ddac0503833f3ed8930528e26d1e60ad35c2446da372d16651ce9"}, - {file = "ruff-0.2.2-py3-none-win32.whl", hash = "sha256:d920499b576f6c68295bc04e7b17b6544d9d05f196bb3aac4358792ef6f34325"}, - {file = "ruff-0.2.2-py3-none-win_amd64.whl", hash = "sha256:cc9a91ae137d687f43a44c900e5d95e9617cb37d4c989e462980ba27039d239d"}, - {file = "ruff-0.2.2-py3-none-win_arm64.whl", hash = "sha256:c9d15fc41e6054bfc7200478720570078f0b41c9ae4f010bcc16bd6f4d1aacdd"}, - {file = "ruff-0.2.2.tar.gz", hash = "sha256:e62ed7f36b3068a30ba39193a14274cd706bc486fad521276458022f7bccb31d"}, + {file = "ruff-0.6.6-py3-none-linux_armv6l.whl", hash = "sha256:f5bc5398457484fc0374425b43b030e4668ed4d2da8ee7fdda0e926c9f11ccfb"}, + {file = "ruff-0.6.6-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:515a698254c9c47bb84335281a170213b3ee5eb47feebe903e1be10087a167ce"}, + {file = "ruff-0.6.6-py3-none-macosx_11_0_arm64.whl", hash = "sha256:6bb1b4995775f1837ab70f26698dd73852bbb82e8f70b175d2713c0354fe9182"}, + {file = "ruff-0.6.6-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69c546f412dfae8bb9cc4f27f0e45cdd554e42fecbb34f03312b93368e1cd0a6"}, + {file = "ruff-0.6.6-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:59627e97364329e4eae7d86fa7980c10e2b129e2293d25c478ebcb861b3e3fd6"}, + {file = "ruff-0.6.6-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:94c3f78c3d32190aafbb6bc5410c96cfed0a88aadb49c3f852bbc2aa9783a7d8"}, + {file = "ruff-0.6.6-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:704da526c1e137f38c8a067a4a975fe6834b9f8ba7dbc5fd7503d58148851b8f"}, + {file = "ruff-0.6.6-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:efeede5815a24104579a0f6320660536c5ffc1c91ae94f8c65659af915fb9de9"}, + {file = "ruff-0.6.6-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e368aef0cc02ca3593eae2fb8186b81c9c2b3f39acaaa1108eb6b4d04617e61f"}, + {file = "ruff-0.6.6-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2653fc3b2a9315bd809725c88dd2446550099728d077a04191febb5ea79a4f79"}, + {file = "ruff-0.6.6-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:bb858cd9ce2d062503337c5b9784d7b583bcf9d1a43c4df6ccb5eab774fbafcb"}, + {file = "ruff-0.6.6-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:488f8e15c01ea9afb8c0ba35d55bd951f484d0c1b7c5fd746ce3c47ccdedce68"}, + {file = "ruff-0.6.6-py3-none-musllinux_1_2_i686.whl", hash = "sha256:aefb0bd15f1cfa4c9c227b6120573bb3d6c4ee3b29fb54a5ad58f03859bc43c6"}, + {file = "ruff-0.6.6-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:a4c0698cc780bcb2c61496cbd56b6a3ac0ad858c966652f7dbf4ceb029252fbe"}, + {file = "ruff-0.6.6-py3-none-win32.whl", hash = "sha256:aadf81ddc8ab5b62da7aae78a91ec933cbae9f8f1663ec0325dae2c364e4ad84"}, + {file = "ruff-0.6.6-py3-none-win_amd64.whl", hash = "sha256:0adb801771bc1f1b8cf4e0a6fdc30776e7c1894810ff3b344e50da82ef50eeb1"}, + {file = "ruff-0.6.6-py3-none-win_arm64.whl", hash = "sha256:4b4d32c137bc781c298964dd4e52f07d6f7d57c03eae97a72d97856844aa510a"}, + {file = "ruff-0.6.6.tar.gz", hash = "sha256:0fc030b6fd14814d69ac0196396f6761921bd20831725c7361e1b8100b818034"}, ] [[package]] @@ -1449,4 +1450,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.8.0" -content-hash = "464b343515c7a7c6124025a4bf206ff43839d474abf655e0e36a13e8b3ef5ed8" +content-hash = "5edc52ee4a8462a10aa6ab130900a9628494dcc4febe935ad4b6bddb28391759" diff --git a/pyproject.toml b/pyproject.toml index f9902821..47f4f38b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,6 @@ +[project] +requires-python = ">=3.8" + [tool.poetry] name = "astrapy" version = "1.5.0" @@ -47,7 +50,7 @@ pytest = "~8.0.0" python-dotenv = "~1.0.1" pytest-httpserver = "~1.0.8" testcontainers = "~3.7.1" -ruff = "~0.2.1" +ruff = "^0.6.6" types-toml = "^0.10.8.7" isort = "^5.13.2" @@ -55,6 +58,9 @@ isort = "^5.13.2" requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" +[tool.ruff.lint] +select = ["E4", "E7", "E9", "F", "FA", "UP"] + [tool.mypy] disallow_any_generics = true disallow_incomplete_defs = true diff --git a/scripts/astrapy_latest_interface.py b/scripts/astrapy_latest_interface.py index 136d402e..2aecc6be 100644 --- a/scripts/astrapy_latest_interface.py +++ b/scripts/astrapy_latest_interface.py @@ -1,10 +1,9 @@ import os import sys -import astrapy - from dotenv import load_dotenv +import astrapy sys.path.append("../") diff --git a/tests/conftest.py b/tests/conftest.py index e2bf1312..00ccc188 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -20,7 +20,7 @@ import functools import warnings -from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, TypedDict +from typing import Any, Awaitable, Callable, TypedDict import pytest from deprecation import UnsupportedWarning @@ -72,10 +72,10 @@ class DataAPICoreCredentials(TypedDict): class DataAPICredentialsInfo(TypedDict): environment: str region: str - secondary_namespace: Optional[str] + secondary_namespace: str | None -def env_region_from_endpoint(api_endpoint: str) -> Tuple[str, str]: +def env_region_from_endpoint(api_endpoint: str) -> tuple[str, str]: parsed = parse_api_endpoint(api_endpoint) if parsed is not None: return (parsed.environment, parsed.region) @@ -105,10 +105,9 @@ async def test_inner(*args: Any, **kwargs: Any) -> Any: for warning in caught_warnings: if warning.category == UnsupportedWarning: raise AssertionError( - ( - "%s uses a function that should be removed: %s" - % (method, str(warning.message)) - ) + + f"{method} uses a function that should be removed: {str(warning.message)}" + ) return rv @@ -132,17 +131,16 @@ def test_inner(*args: Any, **kwargs: Any) -> Any: for warning in caught_warnings: if warning.category == UnsupportedWarning: raise AssertionError( - ( - "%s uses a function that should be removed: %s" - % (method, str(warning.message)) - ) + + f"{method} uses a function that should be removed: {str(warning.message)}" + ) return rv return test_inner -def clean_nulls_from_dict(in_dict: Dict[str, Any]) -> dict[str, Any]: +def clean_nulls_from_dict(in_dict: dict[str, Any]) -> dict[str, Any]: def _cleand(_in: Any) -> Any: if isinstance(_in, list): diff --git a/tests/core/conftest.py b/tests/core/conftest.py index 7e9710a8..41e9813e 100644 --- a/tests/core/conftest.py +++ b/tests/core/conftest.py @@ -19,7 +19,7 @@ from __future__ import annotations import math -from typing import AsyncIterable, Dict, Iterable, List, Optional, Set, TypeVar +from typing import AsyncIterable, Iterable, TypeVar import pytest import pytest_asyncio @@ -111,7 +111,7 @@ def db( if token is None or api_endpoint is None: raise ValueError("Required ASTRA DB configuration is missing") - db_kwargs: Dict[str, str] + db_kwargs: dict[str, str] if data_api_credentials_info["environment"] in {"prod", "dev", "test"}: db_kwargs = {} else: @@ -134,7 +134,7 @@ async def async_db( if token is None or api_endpoint is None: raise ValueError("Required ASTRA DB configuration is missing") - db_kwargs: Dict[str, str] + db_kwargs: dict[str, str] if data_api_credentials_info["environment"] in {"prod", "dev", "test"}: db_kwargs = {} else: @@ -151,14 +151,14 @@ async def async_db( @pytest.fixture(scope="module") def invalid_db( - data_api_core_bad_credentials_kwargs: Dict[str, Optional[str]], + data_api_core_bad_credentials_kwargs: dict[str, str | None], data_api_credentials_info: DataAPICredentialsInfo, ) -> AstraDB: token = data_api_core_bad_credentials_kwargs["token"] api_endpoint = data_api_core_bad_credentials_kwargs["api_endpoint"] namespace = data_api_core_bad_credentials_kwargs.get("namespace") - db_kwargs: Dict[str, str] + db_kwargs: dict[str, str] if data_api_credentials_info["environment"] in {"prod", "dev", "test"}: db_kwargs = {} else: @@ -319,11 +319,11 @@ def pagination_v_collection( INSERT_BATCH_SIZE = 20 # max 20, fixed by API constraints N = 200 # must be EVEN - def _mk_vector(index: int, n_total_steps: int) -> List[float]: + def _mk_vector(index: int, n_total_steps: int) -> list[float]: angle = 2 * math.pi * index / n_total_steps return [math.cos(angle), math.sin(angle)] - inserted_ids: Set[str] = set() + inserted_ids: set[str] = set() for i_batch in _batch_iterable(range(N), INSERT_BATCH_SIZE): batch_ids = empty_v_collection.insert_many( documents=[{"_id": str(i), "$vector": _mk_vector(i, N)} for i in i_batch] diff --git a/tests/core/test_async_db_dml.py b/tests/core/test_async_db_dml.py index 83148909..b733686f 100644 --- a/tests/core/test_async_db_dml.py +++ b/tests/core/test_async_db_dml.py @@ -22,7 +22,7 @@ import datetime import logging import uuid -from typing import Any, Dict, Iterable, List, Literal, Optional, Union, cast +from typing import Any, Iterable, List, Literal, cast import pytest @@ -33,7 +33,7 @@ logger = logging.getLogger(__name__) -def _cleanvec(doc: Dict[str, Any]) -> Dict[str, Any]: +def _cleanvec(doc: dict[str, Any]) -> dict[str, Any]: return {k: v for k, v in doc.items() if k != "$vector"} @@ -135,7 +135,7 @@ async def test_find_find_one_projection( sort = {"$vector": query} options = {"limit": 1} - projs: List[Optional[Dict[str, Literal[1]]]] = [ + projs: list[dict[str, Literal[1]] | None] = [ None, {}, {"text": 1}, @@ -344,7 +344,7 @@ async def test_insert_many( ) -> None: _id0 = str(uuid.uuid4()) _id2 = str(uuid.uuid4()) - documents: List[API_DOC] = [ + documents: list[API_DOC] = [ { "_id": _id0, "name": "Abba", @@ -375,7 +375,7 @@ async def test_chunked_insert_many( async_writable_v_collection: AsyncAstraDBCollection, ) -> None: _ids0 = [str(uuid.uuid4()) for _ in range(20)] - documents0: List[API_DOC] = [ + documents0: list[API_DOC] = [ { "_id": _id, "specs": { @@ -387,7 +387,7 @@ async def test_chunked_insert_many( for doc_idx, _id in enumerate(_ids0) ] - responses0: List[Union[Dict[str, Any], Exception]] = ( + responses0: list[dict[str, Any] | Exception] = ( await async_writable_v_collection.chunked_insert_many(documents0, chunk_size=3) ) assert responses0 is not None @@ -408,7 +408,7 @@ async def test_chunked_insert_many( _ids1 = [ _id0 if idx % 3 == 0 else str(uuid.uuid4()) for idx, _id0 in enumerate(_ids0) ] - documents1: List[API_DOC] = [ + documents1: list[API_DOC] = [ { "_id": _id, "specs": { @@ -458,7 +458,7 @@ async def test_concurrent_chunked_insert_many( async_writable_v_collection: AsyncAstraDBCollection, ) -> None: _ids0 = [str(uuid.uuid4()) for _ in range(20)] - documents0: List[API_DOC] = [ + documents0: list[API_DOC] = [ { "_id": _id, "specs": { @@ -493,7 +493,7 @@ async def test_concurrent_chunked_insert_many( _ids1 = [ _id0 if idx % 3 == 0 else str(uuid.uuid4()) for idx, _id0 in enumerate(_ids0) ] - documents1: List[API_DOC] = [ + documents1: list[API_DOC] = [ { "_id": _id, "specs": { diff --git a/tests/core/test_async_db_dml_pagination.py b/tests/core/test_async_db_dml_pagination.py index 09ac4878..44aabfe1 100644 --- a/tests/core/test_async_db_dml_pagination.py +++ b/tests/core/test_async_db_dml_pagination.py @@ -19,7 +19,6 @@ from __future__ import annotations import logging -from typing import Optional import pytest @@ -43,7 +42,7 @@ ], ) async def test_find_paginated( - prefetched: Optional[int], + prefetched: int | None, async_pagination_v_collection: AsyncAstraDBCollection, ) -> None: options = {"limit": FIND_LIMIT} diff --git a/tests/core/test_db_dml.py b/tests/core/test_db_dml.py index 3d69f7d9..a7006bd6 100644 --- a/tests/core/test_db_dml.py +++ b/tests/core/test_db_dml.py @@ -23,7 +23,7 @@ import json import logging import uuid -from typing import Any, Dict, Iterable, List, Literal, Optional, Set, cast +from typing import Any, Iterable, List, Literal, cast import httpx import pytest @@ -35,7 +35,7 @@ logger = logging.getLogger(__name__) -def _cleanvec(doc: Dict[str, Any]) -> Dict[str, Any]: +def _cleanvec(doc: dict[str, Any]) -> dict[str, Any]: return {k: v for k, v in doc.items() if k != "$vector"} @@ -130,14 +130,14 @@ def test_find_find_one_projection( query = [0.2, 0.6] sort = {"$vector": query} options = {"limit": 1} - projs: List[Optional[Dict[str, Literal[1]]]] = [ + projs: list[dict[str, Literal[1]] | None] = [ None, {}, {"text": 1}, {"$vector": 1}, {"text": 1, "$vector": 1}, ] - exp_fieldsets: List[Set[str]] = [ + exp_fieldsets: list[set[str]] = [ {"$vector", "_id", "otherfield", "anotherfield", "text"}, {"$vector", "_id", "otherfield", "anotherfield", "text"}, {"_id", "text"}, @@ -327,7 +327,7 @@ def test_insert_float32(writable_v_collection: AstraDBCollection, N: int = 2) -> def test_insert_many(writable_v_collection: AstraDBCollection) -> None: _id0 = str(uuid.uuid4()) _id2 = str(uuid.uuid4()) - documents: List[API_DOC] = [ + documents: list[API_DOC] = [ { "_id": _id0, "name": "Abba", @@ -358,7 +358,7 @@ def test_chunked_insert_many( writable_v_collection: AstraDBCollection, ) -> None: _ids0 = [str(uuid.uuid4()) for _ in range(20)] - documents0: List[API_DOC] = [ + documents0: list[API_DOC] = [ { "_id": _id, "specs": { @@ -389,7 +389,7 @@ def test_chunked_insert_many( _ids1 = [ _id0 if idx % 3 == 0 else str(uuid.uuid4()) for idx, _id0 in enumerate(_ids0) ] - documents1: List[API_DOC] = [ + documents1: list[API_DOC] = [ { "_id": _id, "specs": { @@ -439,7 +439,7 @@ def test_concurrent_chunked_insert_many( writable_v_collection: AstraDBCollection, ) -> None: _ids0 = [str(uuid.uuid4()) for _ in range(20)] - documents0: List[API_DOC] = [ + documents0: list[API_DOC] = [ { "_id": _id, "specs": { @@ -472,7 +472,7 @@ def test_concurrent_chunked_insert_many( _ids1 = [ _id0 if idx % 3 == 0 else str(uuid.uuid4()) for idx, _id0 in enumerate(_ids0) ] - documents1: List[API_DOC] = [ + documents1: list[API_DOC] = [ { "_id": _id, "specs": { diff --git a/tests/core/test_db_dml_pagination.py b/tests/core/test_db_dml_pagination.py index 6d94ecb5..23ff3113 100644 --- a/tests/core/test_db_dml_pagination.py +++ b/tests/core/test_db_dml_pagination.py @@ -20,7 +20,6 @@ import logging import time -from typing import Optional import pytest @@ -44,7 +43,7 @@ ], ) def test_find_paginated( - prefetched: Optional[int], + prefetched: int | None, pagination_v_collection: AstraDBCollection, caplog: pytest.LogCaptureFixture, ) -> None: diff --git a/tests/core/test_ops.py b/tests/core/test_ops.py index b6ada837..2139d15c 100644 --- a/tests/core/test_ops.py +++ b/tests/core/test_ops.py @@ -35,7 +35,7 @@ logger = logging.getLogger(__name__) -def find_new_name(existing: List[str], prefix: str) -> str: +def find_new_name(existing: list[str], prefix: str) -> str: candidate_name = prefix for idx in itertools.count(): candidate_name = f"{prefix}{idx}" diff --git a/tests/idiomatic/integration/test_admin.py b/tests/idiomatic/integration/test_admin.py index e4d56e22..e28d6411 100644 --- a/tests/idiomatic/integration/test_admin.py +++ b/tests/idiomatic/integration/test_admin.py @@ -15,7 +15,7 @@ from __future__ import annotations import time -from typing import Any, Awaitable, Callable, List, Optional, Tuple +from typing import Any, Awaitable, Callable import pytest @@ -37,15 +37,15 @@ PRE_DROP_SAFETY_TIMEOUT = 120 -def admin_test_envs_tokens() -> List[Any]: +def admin_test_envs_tokens() -> list[Any]: """ This actually returns a List of `_pytest.mark.structures.ParameterSet` instances, each wrapping a Tuple[str, Optional[str]] = (env, token) """ - envs_tokens: List[Any] = [] + envs_tokens: list[Any] = [] for admin_env in ADMIN_ENV_LIST: markers = [] - pair: Tuple[str, Optional[str]] + pair: tuple[str, str | None] if ADMIN_ENV_VARIABLE_MAP[admin_env]["token"]: pair = (admin_env, ADMIN_ENV_VARIABLE_MAP[admin_env]["token"]) else: @@ -84,7 +84,7 @@ class TestAdmin: ) @pytest.mark.describe("test of the full tour with AstraDBDatabaseAdmin, sync") def test_astra_db_database_admin_sync( - self, admin_env_token: Tuple[str, str] + self, admin_env_token: tuple[str, str] ) -> None: """ Test plan (it has to be a single giant test to use one DB throughout): @@ -206,7 +206,7 @@ def test_astra_db_database_admin_sync( @pytest.mark.describe( "test of the full tour with AstraDBAdmin and client methods, sync" ) - def test_astra_db_admin_sync(self, admin_env_token: Tuple[str, str]) -> None: + def test_astra_db_admin_sync(self, admin_env_token: tuple[str, str]) -> None: """ Test plan (it has to be a single giant test to use the two DBs throughout): - create client -> get_admin @@ -336,7 +336,7 @@ def _waiter2() -> bool: ) @pytest.mark.describe("test of the full tour with AstraDBDatabaseAdmin, async") async def test_astra_db_database_admin_async( - self, admin_env_token: Tuple[str, str] + self, admin_env_token: tuple[str, str] ) -> None: """ Test plan (it has to be a single giant test to use one DB throughout): @@ -471,7 +471,7 @@ async def _awaiter3() -> bool: @pytest.mark.describe( "test of the full tour with AstraDBAdmin and client methods, async" ) - async def test_astra_db_admin_async(self, admin_env_token: Tuple[str, str]) -> None: + async def test_astra_db_admin_async(self, admin_env_token: tuple[str, str]) -> None: """ Test plan (it has to be a single giant test to use the two DBs throughout): - create client -> get_admin diff --git a/tests/idiomatic/integration/test_dml_async.py b/tests/idiomatic/integration/test_dml_async.py index fbad0d1d..46a1bb3b 100644 --- a/tests/idiomatic/integration/test_dml_async.py +++ b/tests/idiomatic/integration/test_dml_async.py @@ -15,7 +15,7 @@ from __future__ import annotations import datetime -from typing import Any, Dict, List +from typing import Any import pytest @@ -258,7 +258,7 @@ async def test_collection_find_async( Nsor = {"seq": SortDocuments.DESCENDING} Nfil = {"seq": {"$exists": True}} - async def _alist(acursor: AsyncCursor) -> List[DocumentType]: + async def _alist(acursor: AsyncCursor) -> list[DocumentType]: return [doc async for doc in acursor] # case 0000 of find-pattern matrix @@ -466,7 +466,7 @@ async def test_collection_cursors_async( document0b = await cursor0b.__anext__() assert "ternary" in document0b - async def _alist(acursor: AsyncCursor) -> List[DocumentType]: + async def _alist(acursor: AsyncCursor) -> list[DocumentType]: return [doc async for doc in acursor] # rewinding, slicing and retrieved @@ -544,7 +544,7 @@ async def test_collection_distinct_nonhashable_async( async_empty_collection: AsyncCollection, ) -> None: acol = async_empty_collection - documents: List[Dict[str, Any]] = [ + documents: list[dict[str, Any]] = [ {}, {"f": 1}, {"f": "a"}, @@ -685,13 +685,13 @@ async def test_collection_include_sort_vector_find_async( ) -> None: q_vector = [10, 9] - async def _alist(acursor: AsyncCursor) -> List[DocumentType]: + async def _alist(acursor: AsyncCursor) -> list[DocumentType]: return [doc async for doc in acursor] # with empty collection for include_sv in [False, True]: for sort_cl_label in ["reg", "vec"]: - sort_cl_e: Dict[str, Any] = ( + sort_cl_e: dict[str, Any] = ( {} if sort_cl_label == "reg" else {"$vector": q_vector} ) vec_expected = include_sv and sort_cl_label == "vec" @@ -726,7 +726,7 @@ async def _alist(acursor: AsyncCursor) -> List[DocumentType]: # with non-empty collection for include_sv in [False, True]: for sort_cl_label in ["reg", "vec"]: - sort_cl_f: Dict[str, Any] = ( + sort_cl_f: dict[str, Any] = ( {} if sort_cl_label == "reg" else {"$vector": q_vector} ) vec_expected = include_sv and sort_cl_label == "vec" diff --git a/tests/idiomatic/integration/test_dml_sync.py b/tests/idiomatic/integration/test_dml_sync.py index 49d7d2c3..fa85feb1 100644 --- a/tests/idiomatic/integration/test_dml_sync.py +++ b/tests/idiomatic/integration/test_dml_sync.py @@ -15,7 +15,7 @@ from __future__ import annotations import datetime -from typing import Any, Dict, List +from typing import Any import pytest @@ -478,7 +478,7 @@ def test_collection_distinct_nonhashable_sync( sync_empty_collection: Collection, ) -> None: col = sync_empty_collection - documents: List[Dict[str, Any]] = [ + documents: list[dict[str, Any]] = [ {}, {"f": 1}, {"f": "a"}, @@ -619,7 +619,7 @@ def test_collection_include_sort_vector_find_sync( # with empty collection for include_sv in [False, True]: for sort_cl_label in ["reg", "vec"]: - sort_cl_e: Dict[str, Any] = ( + sort_cl_e: dict[str, Any] = ( {} if sort_cl_label == "reg" else {"$vector": q_vector} ) vec_expected = include_sv and sort_cl_label == "vec" @@ -654,7 +654,7 @@ def test_collection_include_sort_vector_find_sync( # with non-empty collection for include_sv in [False, True]: for sort_cl_label in ["reg", "vec"]: - sort_cl_f: Dict[str, Any] = ( + sort_cl_f: dict[str, Any] = ( {} if sort_cl_label == "reg" else {"$vector": q_vector} ) vec_expected = include_sv and sort_cl_label == "vec" diff --git a/tests/idiomatic/integration/test_exceptions_async.py b/tests/idiomatic/integration/test_exceptions_async.py index 7eb851b4..6477f024 100644 --- a/tests/idiomatic/integration/test_exceptions_async.py +++ b/tests/idiomatic/integration/test_exceptions_async.py @@ -14,8 +14,6 @@ from __future__ import annotations -from typing import List - import pytest from astrapy import AsyncCollection, AsyncDatabase @@ -62,8 +60,7 @@ async def test_collection_insert_many_insert_failures_async( self, async_empty_collection: AsyncCollection, ) -> None: - - async def _alist(acursor: AsyncCursor) -> List[DocumentType]: + async def _alist(acursor: AsyncCursor) -> list[DocumentType]: return [doc async for doc in acursor] acol = async_empty_collection diff --git a/tests/idiomatic/unit/test_apicommander.py b/tests/idiomatic/unit/test_apicommander.py index ffc28325..11ced3fb 100644 --- a/tests/idiomatic/unit/test_apicommander.py +++ b/tests/idiomatic/unit/test_apicommander.py @@ -16,7 +16,6 @@ import json import logging -from typing import Optional import pytest from pytest_httpserver import HTTPServer @@ -82,7 +81,7 @@ def test_apicommander_request_sync(self, httpserver: HTTPServer) -> None: callers=[("cn", "cv")], ) - def hv_matcher(hk: str, hv: Optional[str], ev: str) -> bool: + def hv_matcher(hk: str, hv: str | None, ev: str) -> bool: if hk == "v": return hv == ev elif hk.lower() == "user-agent": @@ -130,7 +129,7 @@ async def test_apicommander_request_async(self, httpserver: HTTPServer) -> None: callers=[("cn", "cv")], ) - def hv_matcher(hk: str, hv: Optional[str], ev: str) -> bool: + def hv_matcher(hk: str, hv: str | None, ev: str) -> bool: if hk == "v": return hv == ev elif hk.lower() == "user-agent": diff --git a/tests/idiomatic/unit/test_collection_options.py b/tests/idiomatic/unit/test_collection_options.py index f6e09244..8b98bf5e 100644 --- a/tests/idiomatic/unit/test_collection_options.py +++ b/tests/idiomatic/unit/test_collection_options.py @@ -18,7 +18,7 @@ from __future__ import annotations -from typing import Any, Dict, List, Tuple +from typing import Any import pytest @@ -27,7 +27,7 @@ @pytest.mark.describe("test of recasting the collection options from the api") def test_recast_api_collection_dict() -> None: - api_coll_descs: List[Tuple[Dict[str, Any], Dict[str, Any]]] = [ + api_coll_descs: list[tuple[dict[str, Any], dict[str, Any]]] = [ # minimal: ( { diff --git a/tests/idiomatic/unit/test_document_extractors.py b/tests/idiomatic/unit/test_document_extractors.py index 985f28e8..3131c67a 100644 --- a/tests/idiomatic/unit/test_document_extractors.py +++ b/tests/idiomatic/unit/test_document_extractors.py @@ -14,7 +14,7 @@ from __future__ import annotations -from typing import Any, Dict, List +from typing import Any import pytest @@ -57,7 +57,7 @@ def test_dotted_fieldname_document_extractor(self) -> None: } def assert_extracts( - document: Dict[str, Any], key: str, expected: List[Any] + document: dict[str, Any], key: str, expected: list[Any] ) -> None: _extractor = _create_document_key_extractor(key) _extracted = list(_extractor(document)) diff --git a/tests/preprocess_env.py b/tests/preprocess_env.py index 876e9945..244762dd 100644 --- a/tests/preprocess_env.py +++ b/tests/preprocess_env.py @@ -22,7 +22,6 @@ import os import time -from typing import List, Optional from testcontainers.compose import DockerCompose @@ -40,18 +39,18 @@ IS_ASTRA_DB: bool DOCKER_COMPOSE_LOCAL_DATA_API: bool -SECONDARY_NAMESPACE: Optional[str] = None -ASTRA_DB_API_ENDPOINT: Optional[str] = None -ASTRA_DB_APPLICATION_TOKEN: Optional[str] = None -ASTRA_DB_KEYSPACE: Optional[str] = None -LOCAL_DATA_API_USERNAME: Optional[str] = None -LOCAL_DATA_API_PASSWORD: Optional[str] = None -LOCAL_DATA_API_APPLICATION_TOKEN: Optional[str] = None -LOCAL_DATA_API_ENDPOINT: Optional[str] = None -LOCAL_DATA_API_KEYSPACE: Optional[str] = None - -ASTRA_DB_TOKEN_PROVIDER: Optional[TokenProvider] = None -LOCAL_DATA_API_TOKEN_PROVIDER: Optional[TokenProvider] = None +SECONDARY_NAMESPACE: str | None = None +ASTRA_DB_API_ENDPOINT: str | None = None +ASTRA_DB_APPLICATION_TOKEN: str | None = None +ASTRA_DB_KEYSPACE: str | None = None +LOCAL_DATA_API_USERNAME: str | None = None +LOCAL_DATA_API_PASSWORD: str | None = None +LOCAL_DATA_API_APPLICATION_TOKEN: str | None = None +LOCAL_DATA_API_ENDPOINT: str | None = None +LOCAL_DATA_API_KEYSPACE: str | None = None + +ASTRA_DB_TOKEN_PROVIDER: TokenProvider | None = None +LOCAL_DATA_API_TOKEN_PROVIDER: TokenProvider | None = None # idiomatic-related settings if "LOCAL_DATA_API_ENDPOINT" in os.environ: @@ -114,7 +113,6 @@ is_docker_compose_started = False if DOCKER_COMPOSE_LOCAL_DATA_API: if not is_docker_compose_started: - """ Note: this is a trick to invoke `docker compose` as opposed to `docker-compose` while using testcontainers < 4. @@ -133,8 +131,7 @@ """ class RedefineCommandDockerCompose(DockerCompose): - - def docker_compose_command(self) -> List[str]: + def docker_compose_command(self) -> list[str]: docker_compose_cmd = ["docker", "compose"] for file in self.compose_file_names: docker_compose_cmd += ["-f", file] diff --git a/tests/vectorize_idiomatic/conftest.py b/tests/vectorize_idiomatic/conftest.py index f27967d9..9dbf2501 100644 --- a/tests/vectorize_idiomatic/conftest.py +++ b/tests/vectorize_idiomatic/conftest.py @@ -19,7 +19,7 @@ from __future__ import annotations import os -from typing import Any, Dict, Iterable +from typing import Any, Iterable import pytest @@ -59,7 +59,7 @@ def async_database( @pytest.fixture(scope="session") -def service_collection_parameters() -> Iterable[Dict[str, Any]]: +def service_collection_parameters() -> Iterable[dict[str, Any]]: yield { "dimension": 1536, "provider": "openai", @@ -72,7 +72,7 @@ def service_collection_parameters() -> Iterable[Dict[str, Any]]: def sync_service_collection( data_api_credentials_kwargs: DataAPICredentials, sync_database: Database, - service_collection_parameters: Dict[str, Any], + service_collection_parameters: dict[str, Any], ) -> Iterable[Collection]: """ An actual collection on DB, in the main keyspace. diff --git a/tests/vectorize_idiomatic/integration/test_vectorize_methods_async.py b/tests/vectorize_idiomatic/integration/test_vectorize_methods_async.py index 9a0c5fd8..cbeb2a15 100644 --- a/tests/vectorize_idiomatic/integration/test_vectorize_methods_async.py +++ b/tests/vectorize_idiomatic/integration/test_vectorize_methods_async.py @@ -14,7 +14,7 @@ from __future__ import annotations -from typing import Any, Dict, List +from typing import Any import pytest @@ -36,7 +36,7 @@ class TestVectorizeMethodsAsync: async def test_collection_methods_vectorize_async( self, async_empty_service_collection: AsyncCollection, - service_collection_parameters: Dict[str, Any], + service_collection_parameters: dict[str, Any], ) -> None: acol = async_empty_service_collection service_vector_dimension = service_collection_parameters["dimension"] @@ -187,12 +187,12 @@ async def test_collection_include_sort_vector_vectorize_find_async( def _is_vector(v: Any) -> bool: return isinstance(v, list) and isinstance(v[0], float) - async def _alist(acursor: AsyncCursor) -> List[DocumentType]: + async def _alist(acursor: AsyncCursor) -> list[DocumentType]: return [doc async for doc in acursor] for include_sv in [False, True]: for sort_cl_label in ["vze"]: - sort_cl_e: Dict[str, Any] = {"$vectorize": q_text} + sort_cl_e: dict[str, Any] = {"$vectorize": q_text} vec_expected = include_sv and sort_cl_label == "vze" # pristine iterator this_ite_1 = async_empty_service_collection.find( @@ -228,7 +228,7 @@ async def _alist(acursor: AsyncCursor) -> List[DocumentType]: # with non-empty collection for include_sv in [False, True]: for sort_cl_label in ["vze"]: - sort_cl_f: Dict[str, Any] = {"$vectorize": q_text} + sort_cl_f: dict[str, Any] = {"$vectorize": q_text} vec_expected = include_sv and sort_cl_label == "vze" # pristine iterator this_ite_1 = async_empty_service_collection.find( @@ -274,7 +274,7 @@ async def _alist(acursor: AsyncCursor) -> List[DocumentType]: async def test_database_create_collection_dimension_mismatch_failure_async( self, async_database: AsyncDatabase, - service_collection_parameters: Dict[str, Any], + service_collection_parameters: dict[str, Any], ) -> None: with pytest.raises(DataAPIResponseException): await async_database.create_collection( diff --git a/tests/vectorize_idiomatic/integration/test_vectorize_methods_sync.py b/tests/vectorize_idiomatic/integration/test_vectorize_methods_sync.py index 51e6d092..cf873224 100644 --- a/tests/vectorize_idiomatic/integration/test_vectorize_methods_sync.py +++ b/tests/vectorize_idiomatic/integration/test_vectorize_methods_sync.py @@ -14,7 +14,7 @@ from __future__ import annotations -from typing import Any, Dict +from typing import Any import pytest @@ -28,7 +28,7 @@ class TestVectorizeMethodsSync: def test_collection_methods_vectorize_sync( self, sync_empty_service_collection: Collection, - service_collection_parameters: Dict[str, Any], + service_collection_parameters: dict[str, Any], ) -> None: col = sync_empty_service_collection service_vector_dimension = service_collection_parameters["dimension"] @@ -178,7 +178,7 @@ def _is_vector(v: Any) -> bool: for include_sv in [False, True]: for sort_cl_label in ["vze"]: - sort_cl_e: Dict[str, Any] = {"$vectorize": q_text} + sort_cl_e: dict[str, Any] = {"$vectorize": q_text} vec_expected = include_sv and sort_cl_label == "vze" # pristine iterator this_ite_1 = sync_empty_service_collection.find( @@ -214,7 +214,7 @@ def _is_vector(v: Any) -> bool: # with non-empty collection for include_sv in [False, True]: for sort_cl_label in ["vze"]: - sort_cl_f: Dict[str, Any] = {"$vectorize": q_text} + sort_cl_f: dict[str, Any] = {"$vectorize": q_text} vec_expected = include_sv and sort_cl_label == "vze" # pristine iterator this_ite_1 = sync_empty_service_collection.find( @@ -256,7 +256,7 @@ def _is_vector(v: Any) -> bool: def test_database_create_collection_dimension_mismatch_failure_sync( self, sync_database: Database, - service_collection_parameters: Dict[str, Any], + service_collection_parameters: dict[str, Any], ) -> None: with pytest.raises(DataAPIResponseException): sync_database.create_collection( diff --git a/tests/vectorize_idiomatic/integration/test_vectorize_providers.py b/tests/vectorize_idiomatic/integration/test_vectorize_providers.py index cd4e6989..002f67e7 100644 --- a/tests/vectorize_idiomatic/integration/test_vectorize_providers.py +++ b/tests/vectorize_idiomatic/integration/test_vectorize_providers.py @@ -16,7 +16,7 @@ import os import sys -from typing import Any, Dict, List, Union +from typing import Any import pytest @@ -31,7 +31,7 @@ from ..vectorize_models import live_test_models -def enabled_vectorize_models(auth_type: str) -> List[Any]: +def enabled_vectorize_models(auth_type: str) -> list[Any]: """ This actually returns a List of `_pytest.mark.structures.ParameterSet` instances, each wrapping a dict with the needed info to test the model @@ -41,7 +41,7 @@ def enabled_vectorize_models(auth_type: str) -> List[Any]: where `tag` = "provider/model/auth_type/[0 or f]" """ all_test_models = list(live_test_models()) - all_model_ids: List[str] = [ + all_model_ids: list[str] = [ str(model_desc["model_tag"]) for model_desc in all_test_models ] # @@ -50,10 +50,10 @@ def enabled_vectorize_models(auth_type: str) -> List[Any]: for test_model in all_test_models if test_model["auth_type_name"] == auth_type ] - at_model_ids: List[str] = [ + at_model_ids: list[str] = [ str(model_desc["model_tag"]) for model_desc in at_test_models ] - at_chosen_models: List[Any] = [] + at_chosen_models: list[Any] = [] if "EMBEDDING_MODEL_TAGS" in os.environ: whitelisted_models = [ _wmd.strip() @@ -91,12 +91,12 @@ class TestVectorizeProviders: def test_vectorize_usage_auth_type_header_sync( self, sync_database: Database, - testable_vectorize_model: Dict[str, Any], + testable_vectorize_model: dict[str, Any], ) -> None: simple_tag = testable_vectorize_model["simple_tag"].lower() # switch betewen header providers according to what is needed # For the time being this is necessary on HEADER only - embedding_api_key: Union[str, EmbeddingHeadersProvider] + embedding_api_key: str | EmbeddingHeadersProvider at_tokens = testable_vectorize_model["auth_type_tokens"] at_token_lnames = {tk.accepted.lower() for tk in at_tokens} if at_token_lnames == {"x-embedding-api-key"}: @@ -201,7 +201,7 @@ def test_vectorize_usage_auth_type_header_sync( def test_vectorize_usage_auth_type_none_sync( self, sync_database: Database, - testable_vectorize_model: Dict[str, Any], + testable_vectorize_model: dict[str, Any], ) -> None: simple_tag = testable_vectorize_model["simple_tag"].lower() dimension = testable_vectorize_model.get("dimension") @@ -281,7 +281,7 @@ def test_vectorize_usage_auth_type_none_sync( def test_vectorize_usage_auth_type_shared_secret_sync( self, sync_database: Database, - testable_vectorize_model: Dict[str, Any], + testable_vectorize_model: dict[str, Any], ) -> None: simple_tag = testable_vectorize_model["simple_tag"].lower() secret_tag = testable_vectorize_model["secret_tag"] diff --git a/tests/vectorize_idiomatic/query_providers.py b/tests/vectorize_idiomatic/query_providers.py index 7ab2fccc..aace0916 100644 --- a/tests/vectorize_idiomatic/query_providers.py +++ b/tests/vectorize_idiomatic/query_providers.py @@ -17,7 +17,6 @@ import json import os import sys -from typing import List from astrapy.info import EmbeddingProviderParameter, FindEmbeddingProvidersResult @@ -113,7 +112,7 @@ def desc_param(param_data: EmbeddingProviderParameter) -> str: for test_model in all_test_models if test_model["auth_type_name"] == auth_type ] - at_model_ids: List[str] = sorted( + at_model_ids: list[str] = sorted( [str(model_desc["model_tag"]) for model_desc in at_test_models] ) if at_model_ids: diff --git a/tests/vectorize_idiomatic/vectorize_models.py b/tests/vectorize_idiomatic/vectorize_models.py index ab366c08..ed347061 100644 --- a/tests/vectorize_idiomatic/vectorize_models.py +++ b/tests/vectorize_idiomatic/vectorize_models.py @@ -16,7 +16,7 @@ import os import sys -from typing import Any, Dict, Iterable, List, Tuple +from typing import Any, Iterable from astrapy.defaults import ( EMBEDDING_HEADER_API_KEY, @@ -85,7 +85,7 @@ ("voyageAI", "voyage-code-2"): CODE_TEST_ASSETS, } -USE_INSERT_ONE_MAP: Dict[Tuple[str, str], bool] = { +USE_INSERT_ONE_MAP: dict[tuple[str, str], bool] = { # ("upstageAI", "solar-1-mini-embedding"): True, } @@ -172,8 +172,7 @@ } -def live_test_models() -> Iterable[Dict[str, Any]]: - +def live_test_models() -> Iterable[dict[str, Any]]: def _from_validation(pspec: EmbeddingProviderParameter) -> int: assert pspec.parameter_type == "number" if "numericRange" in pspec.validation: @@ -181,7 +180,7 @@ def _from_validation(pspec: EmbeddingProviderParameter) -> int: m1: int = pspec.validation["numericRange"][1] return (m0 + m1) // 2 elif "options" in pspec.validation: - options: List[int] = pspec.validation["options"] + options: list[int] = pspec.validation["options"] if len(options) > 1: return options[1] else: