diff --git a/CHANGES b/CHANGES
index e615509a..4ea5fab3 100644
--- a/CHANGES
+++ b/CHANGES
@@ -1,5 +1,17 @@
-(master)
+v. 1.4.0
 ========
+DatabaseAdmin classes retain a reference to the Async/Database instance that spawned it, if any
+    - introduced a spawner_database parameter to database admin constructors
+    - database admin can retroactively set the db's working namespace upon creation of same
+    - Idiom `database = client.get_database(...); database.get_database_admin().create_namespace("the_namespace", update_db_namespace=True)`
+Database (and AsyncDatabase) classes admit null namespace:
+    - default to "default_namespace" only for Astra, otherwise null
+    - as long as null, most operations are unavailable and error out
+    - a `use_namespace` method to (mutably) set the working namespace on a database instance
+AstraDBDatabaseAdmin class is fully region-aware:
+    - can be instantiated with an endpoint (also `id` parameter aliased to `api_endpoint`)
+    - requires a region to be specified with an ID, unless auto-guess can be done
+VectorizeOps: support for find_embedding_providers Database method
 Support for multiple-header embedding api keys:
     - `EmbeddingHeadersProvider` classes for `embedding_api_key` parameter
     - AWS header provider in addition to the regular one-header one
@@ -8,12 +20,14 @@ Testing:
     - restructure CI to fully support HCD alongside Astra DB
     - add details for testing new embedding providers
 
+
 v. 1.3.1
 ========
 Fixed bug in parsing endpoint domain names containing hyphens (#287), by @bradfordcp
 Added isort for source code formatting
 Updated abstractions diagram in README for non-Astra environments
 
+
 v. 1.3.0
 ========
 Integration testing covers Astra and nonAstra smoothly:
@@ -39,6 +53,7 @@ Remove several long-deprecated methods from **core API** (i.e. internal changes)
     AsyncAstraDB.truncate_collection => AsyncAstraDBCollectionclear
 Add support for null tokens in the core library
 
+
 v. 1.2.1
 ========
 Raise default chunk size for insert_many to 50
diff --git a/README.md b/README.md
index 4563828a..32737415 100644
--- a/README.md
+++ b/README.md
@@ -483,6 +483,12 @@ from astrapy.info import (
     CollectionVectorOptions,
     CollectionOptions,
     CollectionDescriptor,
+    EmbeddingProviderParameter,
+    EmbeddingProviderModel,
+    EmbeddingProviderToken,
+    EmbeddingProviderAuthentication,
+    EmbeddingProvider,
+    FindEmbeddingProvidersResult,
 )
 ```
diff --git a/astrapy/__init__.py b/astrapy/__init__.py
index c076d9a0..1395435c 100644
--- a/astrapy/__init__.py
+++ b/astrapy/__init__.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from __future__ import annotations + import importlib.metadata import os diff --git a/astrapy/admin.py b/astrapy/admin.py index 64587558..73be9672 100644 --- a/astrapy/admin.py +++ b/astrapy/admin.py @@ -18,11 +18,13 @@ import logging import re import time +import warnings from abc import ABC, abstractmethod from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union import httpx +from deprecation import DeprecatedWarning from astrapy.api_commander import APICommander from astrapy.authentication import coerce_token_provider @@ -39,7 +41,7 @@ ops_recast_method_sync, to_dataapi_timeout_exception, ) -from astrapy.info import AdminDatabaseInfo, DatabaseInfo +from astrapy.info import AdminDatabaseInfo, DatabaseInfo, FindEmbeddingProvidersResult if TYPE_CHECKING: from astrapy import AsyncDatabase, Database @@ -90,6 +92,10 @@ } API_PATH_ENV_MAP = { + Environment.PROD: "/api/json", + Environment.DEV: "/api/json", + Environment.TEST: "/api/json", + # Environment.DSE: "", Environment.HCD: "", Environment.CASSANDRA: "", @@ -97,6 +103,10 @@ } API_VERSION_ENV_MAP = { + Environment.PROD: "/v1", + Environment.DEV: "/v1", + Environment.TEST: "/v1", + # Environment.DSE: "v1", Environment.HCD: "v1", Environment.CASSANDRA: "v1", @@ -261,7 +271,7 @@ async def async_fetch_raw_database_info_from_id_token( def fetch_database_info( api_endpoint: str, token: Optional[str], - namespace: str, + namespace: Optional[str], max_time_ms: Optional[int] = None, ) -> Optional[DatabaseInfo]: """ @@ -271,6 +281,7 @@ def fetch_database_info( api_endpoint: a full API endpoint for the Data Api. token: a valid token to access the database information. namespace: the desired namespace that will be used in the result. + If not specified, the resulting database info will show it as None. max_time_ms: a timeout, in milliseconds, for waiting on a response. Returns: @@ -288,7 +299,7 @@ def fetch_database_info( max_time_ms=max_time_ms, ) raw_info = gd_response["info"] - if namespace not in raw_info["keyspaces"]: + if namespace is not None and namespace not in raw_info["keyspaces"]: raise DevOpsAPIException(f"Namespace {namespace} not found on DB.") else: return DatabaseInfo( @@ -306,7 +317,7 @@ def fetch_database_info( async def async_fetch_database_info( api_endpoint: str, token: Optional[str], - namespace: str, + namespace: Optional[str], max_time_ms: Optional[int] = None, ) -> Optional[DatabaseInfo]: """ @@ -317,6 +328,7 @@ async def async_fetch_database_info( api_endpoint: a full API endpoint for the Data Api. token: a valid token to access the database information. namespace: the desired namespace that will be used in the result. + If not specified, the resulting database info will show it as None. max_time_ms: a timeout, in milliseconds, for waiting on a response. Returns: @@ -334,7 +346,7 @@ async def async_fetch_database_info( max_time_ms=max_time_ms, ) raw_info = gd_response["info"] - if namespace not in raw_info["keyspaces"]: + if namespace is not None and namespace not in raw_info["keyspaces"]: raise DevOpsAPIException(f"Namespace {namespace} not found on DB.") else: return DatabaseInfo( @@ -383,6 +395,86 @@ def _recast_as_admin_database_info( ) +def normalize_api_endpoint( + id_or_endpoint: str, + region: Optional[str], + token: TokenProvider, + environment: str, + max_time_ms: Optional[int] = None, +) -> str: + """ + Ensure that a id(+region) / endpoint init signature is normalized into + an api_endpoint string. 
+ + This is an impure function: if necessary, attempt a DevOps API call to + integrate the information (i.e. if a DB ID without region is passed). + + This function is tasked with raising an exception if region is passed along + with an API endpoint (and they do not match). + + Args: + id_or_endpoint: either the Database ID or a full standard endpoint. + region: a string with the database region. + token: a TokenProvider for the possible DevOps request to issue. + environment: one of the Astra DB `astrapy.constants.Environment` values. + max_time_ms: used in case the DevOps API request is necessary. + + Returns: + a normalized API Endpoint string (unless it raises an exception). + """ + _api_endpoint: str + parsed_endpoint = parse_api_endpoint(id_or_endpoint) + if parsed_endpoint is not None: + if region is not None and region != parsed_endpoint.region: + raise ValueError( + "An explicit `region` parameter is provided, which does not match " + "the supplied API endpoint. Please refrain from specifying `region`." + ) + _api_endpoint = id_or_endpoint + else: + # it's a genuine ID + _region: str + if region: + _region = region + else: + logger.info(f"fetching raw database info for {id_or_endpoint}") + this_db_info = fetch_raw_database_info_from_id_token( + id=id_or_endpoint, + token=token.get_token(), + environment=environment, + max_time_ms=max_time_ms, + ) + logger.info(f"finished fetching raw database info for {id_or_endpoint}") + _region = this_db_info["info"]["region"] + _api_endpoint = build_api_endpoint( + environment=environment, + database_id=id_or_endpoint, + region=_region, + ) + return _api_endpoint.strip("/") + + +def normalize_id_endpoint_parameters( + id: Optional[str], api_endpoint: Optional[str] +) -> str: + if id is None: + if api_endpoint is None: + raise ValueError( + "Exactly one of the `id` and `api_endpoint` " + "synonymous parameters must be passed." + ) + else: + return api_endpoint + else: + if api_endpoint is not None: + raise ValueError( + "The `id` and `api_endpoint` synonymous parameters " + "cannot be supplied at the same time." + ) + else: + return id + + class AstraDBAdmin: """ An "admin" object, able to perform administrative tasks at the databases @@ -834,6 +926,7 @@ def create_database( ) return AstraDBDatabaseAdmin.from_astra_db_admin( id=new_database_id, + region=region, astra_db_admin=self, ) else: @@ -934,6 +1027,7 @@ async def async_create_database( ) return AstraDBDatabaseAdmin.from_astra_db_admin( id=new_database_id, + region=region, astra_db_admin=self, ) else: @@ -1098,12 +1192,31 @@ async def async_drop_database( f"Could not issue a successful terminate-database DevOps API request for {id}." ) - def get_database_admin(self, id: str) -> AstraDBDatabaseAdmin: + def get_database_admin( + self, + id: Optional[str] = None, + *, + api_endpoint: Optional[str] = None, + region: Optional[str] = None, + max_time_ms: Optional[int] = None, + ) -> AstraDBDatabaseAdmin: """ Create an AstraDBDatabaseAdmin object for admin work within a certain database. Args: - id: the ID of the target database, e. g. "01234567-89ab-cdef-0123-456789abcdef". + id: the target database ID (e.g. `01234567-89ab-cdef-0123-456789abcdef`) + or the corresponding API Endpoint + (e.g. `https://-.apps.astra.datastax.com`). + api_endpoint: a named alias for the `id` first (positional) parameter, + with the same meaning. It cannot be passed together with `id`. + region: the region to use for connecting to the database. The + database must be located in that region. 
+ The region cannot be specified when the API endoint is used as `id`. + Note that if this parameter is not passed, and cannot be inferred + from the API endpoint, an additional DevOps API request is made + to determine the default region and use it subsequently. + max_time_ms: a timeout, in milliseconds, for the DevOps API + HTTP request should it be necessary (see the `region` argument). Returns: An AstraDBDatabaseAdmin instance representing the requested database. @@ -1123,15 +1236,20 @@ def get_database_admin(self, id: str) -> AstraDBDatabaseAdmin: `create_database` method. """ + _id_or_endpoint = normalize_id_endpoint_parameters(id, api_endpoint) + return AstraDBDatabaseAdmin.from_astra_db_admin( - id=id, + id=_id_or_endpoint, + region=region, astra_db_admin=self, + max_time_ms=max_time_ms, ) def get_database( self, - id: str, + id: Optional[str] = None, *, + api_endpoint: Optional[str] = None, token: Optional[Union[str, TokenProvider]] = None, namespace: Optional[str] = None, region: Optional[str] = None, @@ -1144,7 +1262,11 @@ def get_database( when doing data-level work (such as creating/managing collections). Args: - id: e. g. "01234567-89ab-cdef-0123-456789abcdef". + id: the target database ID (e.g. `01234567-89ab-cdef-0123-456789abcdef`) + or the corresponding API Endpoint + (e.g. `https://-.apps.astra.datastax.com`). + api_endpoint: a named alias for the `id` first (positional) parameter, + with the same meaning. It cannot be passed together with `id`. token: if supplied, is passed to the Database instead of the one set for this object. This can be either a literal token string or a subclass of @@ -1155,11 +1277,10 @@ def get_database( the default namespace for the target database. region: the region to use for connecting to the database. The database must be located in that region. - Note that if this parameter is not passed, an additional - DevOps API request is made to determine the default region - and use it subsequently. - If both `namespace` and `region` are missing, a single - DevOps API request is made. + The region cannot be specified when the API endoint is used as `id`. + Note that if this parameter is not passed, and cannot be inferred + from the API endpoint, an additional DevOps API request is made + to determine the default region and use it subsequently. api_path: path to append to the API Endpoint. In typical usage, this should be left to its default of "/api/json". api_version: version specifier to append to the API path. In typical @@ -1187,30 +1308,36 @@ def get_database( # lazy importing here to avoid circular dependency from astrapy import Database - # need to inspect for values? - this_db_info: Optional[AdminDatabaseInfo] = None - # handle overrides + _id_or_endpoint = normalize_id_endpoint_parameters(id, api_endpoint) + _token = coerce_token_provider(token) or self.token_provider + + normalized_api_endpoint = normalize_api_endpoint( + id_or_endpoint=_id_or_endpoint, + region=region, + token=_token, + environment=self.environment, + max_time_ms=max_time_ms, + ) + + _namespace: str if namespace: _namespace = namespace else: - if this_db_info is None: - this_db_info = self.database_info(id, max_time_ms=max_time_ms) + parsed_api_endpoint = parse_api_endpoint(normalized_api_endpoint) + if parsed_api_endpoint is None: + raise ValueError( + f"Cannot parse the API endpoint ({normalized_api_endpoint})." 
+ ) + + this_db_info = self.database_info( + parsed_api_endpoint.database_id, + max_time_ms=max_time_ms, + ) _namespace = this_db_info.info.namespace - if region: - _region = region - else: - if this_db_info is None: - this_db_info = self.database_info(id, max_time_ms=max_time_ms) - _region = this_db_info.info.region - _api_endpoint = build_api_endpoint( - environment=self.environment, - database_id=id, - region=_region, - ) return Database( - api_endpoint=_api_endpoint, + api_endpoint=normalized_api_endpoint, token=_token, namespace=_namespace, caller_name=self._caller_name, @@ -1222,8 +1349,9 @@ def get_database( def get_async_database( self, - id: str, + id: Optional[str] = None, *, + api_endpoint: Optional[str] = None, token: Optional[Union[str, TokenProvider]] = None, namespace: Optional[str] = None, region: Optional[str] = None, @@ -1240,6 +1368,7 @@ def get_async_database( return self.get_database( id=id, + api_endpoint=api_endpoint, token=token, namespace=namespace, region=region, @@ -1256,6 +1385,7 @@ class DatabaseAdmin(ABC): """ environment: str + spawner_database: Union[Database, AsyncDatabase] @abstractmethod def list_namespaces(self, *pargs: Any, **kwargs: Any) -> List[str]: @@ -1263,7 +1393,13 @@ def list_namespaces(self, *pargs: Any, **kwargs: Any) -> List[str]: ... @abstractmethod - def create_namespace(self, name: str, *pargs: Any, **kwargs: Any) -> Dict[str, Any]: + def create_namespace( + self, + name: str, + *, + update_db_namespace: Optional[bool] = None, + **kwargs: Any, + ) -> Dict[str, Any]: """ Create a namespace in the database, returning {'ok': 1} if successful. """ @@ -1286,7 +1422,7 @@ async def async_list_namespaces(self, *pargs: Any, **kwargs: Any) -> List[str]: @abstractmethod async def async_create_namespace( - self, name: str, *pargs: Any, **kwargs: Any + self, name: str, *, update_db_namespace: Optional[bool] = None, **kwargs: Any ) -> Dict[str, Any]: """ Create a namespace in the database, returning {'ok': 1} if successful. @@ -1314,6 +1450,23 @@ def get_async_database(self, *pargs: Any, **kwargs: Any) -> AsyncDatabase: """Get an AsyncDatabase object from this database admin.""" ... + @abstractmethod + def find_embedding_providers( + self, *pargs: Any, **kwargs: Any + ) -> FindEmbeddingProvidersResult: + """Query the Data API for the available embedding providers.""" + ... + + @abstractmethod + async def async_find_embedding_providers( + self, *pargs: Any, **kwargs: Any + ) -> FindEmbeddingProvidersResult: + """ + Query the Data API for the available embedding providers. + (Async version of the method.) + """ + ... + class AstraDBDatabaseAdmin(DatabaseAdmin): """ @@ -1325,10 +1478,20 @@ class AstraDBDatabaseAdmin(DatabaseAdmin): created by a method call on an AstraDBAdmin. Args: - id: e. g. "01234567-89ab-cdef-0123-456789abcdef". + id: the target database ID (e.g. `01234567-89ab-cdef-0123-456789abcdef`) + or the corresponding API Endpoint + (e.g. `https://-.apps.astra.datastax.com`). + api_endpoint: a named alias for the `id` first (positional) parameter, + with the same meaning. It cannot be passed together with `id`. token: an access token with enough permission to perform admin tasks. This can be either a literal token string or a subclass of `astrapy.authentication.TokenProvider`. + region: the region to use for connecting to the database. The + database must be located in that region. + The region cannot be specified when the API endoint is used as `id`. 
+ Note that if this parameter is not passed, and cannot be inferred + from the API endpoint, an additional DevOps API request is made + to determine the default region and use it subsequently. environment: a label, whose value is one of Environment.PROD (default), Environment.DEV or Environment.TEST. caller_name: name of the application, or framework, on behalf of which @@ -1340,6 +1503,20 @@ class AstraDBDatabaseAdmin(DatabaseAdmin): determined from the API Endpoint. dev_ops_api_version: this can specify a custom version of the DevOps API (such as "v2"). Generally not needed. + api_path: path to append to the API Endpoint. In typical usage, this + class is created by a method such as `Database.get_database_admin()`, + which passes the matching value. Generally to be left to its Astra DB + default of "/api/json". + api_version: version specifier to append to the API path. In typical + usage, this class is created by a method such as + `Database.get_database_admin()`, which passes the matching value. + Generally to be left to its Astra DB default of "/v1". + spawner_database: either a Database or an AsyncDatabase instance. This represents + the database class which spawns this admin object, so that, if required, + a namespace creation can retroactively "use" the new namespace in the spawner. + Used to enable the Async/Database.get_admin_database().create_namespace() pattern. + max_time_ms: a timeout, in milliseconds, for the DevOps API + HTTP request should it be necessary (see the `region` argument). Example: >>> from astrapy import DataAPIClient @@ -1358,27 +1535,101 @@ class AstraDBDatabaseAdmin(DatabaseAdmin): def __init__( self, - id: str, + id: Optional[str] = None, *, + api_endpoint: Optional[str] = None, token: Optional[Union[str, TokenProvider]] = None, + region: Optional[str] = None, environment: Optional[str] = None, caller_name: Optional[str] = None, caller_version: Optional[str] = None, dev_ops_url: Optional[str] = None, dev_ops_api_version: Optional[str] = None, + api_path: Optional[str] = None, + api_version: Optional[str] = None, + spawner_database: Optional[Union[Database, AsyncDatabase]] = None, + max_time_ms: Optional[int] = None, ) -> None: - self.id = id + # lazy import here to avoid circular dependency + from astrapy.database import Database + self.token_provider = coerce_token_provider(token) self.environment = (environment or Environment.PROD).lower() + + _id_or_endpoint = normalize_id_endpoint_parameters(id, api_endpoint) + + normalized_api_endpoint = normalize_api_endpoint( + id_or_endpoint=_id_or_endpoint, + region=region, + token=self.token_provider, + environment=self.environment, + max_time_ms=max_time_ms, + ) + + self.api_endpoint = normalized_api_endpoint + parsed_api_endpoint = parse_api_endpoint(self.api_endpoint) + if parsed_api_endpoint is None: + raise ValueError( + f"Cannot parse the provided API endpoint ({self.api_endpoint})." + ) + + self._database_id = parsed_api_endpoint.database_id + self._region = parsed_api_endpoint.region + if parsed_api_endpoint.environment != self.environment: + raise ValueError( + "Environment mismatch between client and provided " + "API endpoint. You can try adding " + f'`environment="{parsed_api_endpoint.environment}"` ' + "to the class constructor." 
+ ) + # + self.caller_name = caller_name + self.caller_version = caller_version + self._astra_db_admin = AstraDBAdmin( token=self.token_provider, environment=self.environment, - caller_name=caller_name, - caller_version=caller_version, + caller_name=self.caller_name, + caller_version=self.caller_version, dev_ops_url=dev_ops_url, dev_ops_api_version=dev_ops_api_version, ) + # API Commander (for the vectorizeOps invocations) + self.api_path = ( + api_path if api_path is not None else API_PATH_ENV_MAP[self.environment] + ) + self.api_version = ( + api_version + if api_version is not None + else API_VERSION_ENV_MAP[self.environment] + ) + self._commander_headers = { + DEFAULT_AUTH_HEADER: self.token_provider.get_token(), + } + self._api_commander = APICommander( + api_endpoint=self.api_endpoint, + path="/".join(comp for comp in [self.api_path, self.api_version] if comp), + headers=self._commander_headers, + callers=[(self.caller_name, self.caller_version)], + ) + + if spawner_database is not None: + self.spawner_database = spawner_database + else: + # leaving the namespace to its per-environment default + # (a task for the Database) + self.spawner_database = Database( + api_endpoint=self.api_endpoint, + token=self.token_provider, + namespace=None, + caller_name=self.caller_name, + caller_version=self.caller_version, + environment=self.environment, + api_path=self.api_path, + api_version=self.api_version, + ) + def __repr__(self) -> str: env_desc: str if self.environment == Environment.PROD: @@ -1386,7 +1637,7 @@ def __repr__(self) -> str: else: env_desc = f', environment="{self.environment}"' return ( - f'{self.__class__.__name__}(id="{self.id}", ' + f'{self.__class__.__name__}(api_endpoint="{self.api_endpoint}", ' f'"{str(self.token_provider)[:12]}..."{env_desc})' ) @@ -1394,7 +1645,7 @@ def __eq__(self, other: Any) -> bool: if isinstance(other, AstraDBDatabaseAdmin): return all( [ - self.id == other.id, + self.api_endpoint == other.api_endpoint, self.token_provider == other.token_provider, self.environment == other.environment, self._astra_db_admin == other._astra_db_admin, @@ -1407,6 +1658,7 @@ def _copy( self, id: Optional[str] = None, token: Optional[Union[str, TokenProvider]] = None, + region: Optional[str] = None, environment: Optional[str] = None, caller_name: Optional[str] = None, caller_version: Optional[str] = None, @@ -1414,8 +1666,9 @@ def _copy( dev_ops_api_version: Optional[str] = None, ) -> AstraDBDatabaseAdmin: return AstraDBDatabaseAdmin( - id=id or self.id, + id=id or self._database_id, token=coerce_token_provider(token) or self.token_provider, + region=region or self._region, environment=environment or self.environment, caller_name=caller_name or self._astra_db_admin._caller_name, caller_version=caller_version or self._astra_db_admin._caller_version, @@ -1487,17 +1740,53 @@ def set_caller( logger.info(f"setting caller to {caller_name}/{caller_version}") self._astra_db_admin.set_caller(caller_name, caller_version) + @property + def id(self) -> str: + """ + The ID of this database admin. + + Example: + >>> my_db_admin.id + '01234567-89ab-cdef-0123-456789abcdef' + """ + return self._database_id + + @property + def region(self) -> str: + """ + The region for this database admin. 
+ + Example: + >>> my_db_admin.region + 'us-east-1' + """ + return self._region + @staticmethod def from_astra_db_admin( - id: str, *, astra_db_admin: AstraDBAdmin + id: str, + *, + region: Optional[str], + astra_db_admin: AstraDBAdmin, + max_time_ms: Optional[int] = None, ) -> AstraDBDatabaseAdmin: """ Create an AstraDBDatabaseAdmin from an AstraDBAdmin and a database ID. Args: - id: e. g. "01234567-89ab-cdef-0123-456789abcdef". + id: the target database ID (e.g. `01234567-89ab-cdef-0123-456789abcdef`) + or the corresponding API Endpoint + (e.g. `https://-.apps.astra.datastax.com`). + region: the region to use for connecting to the database. The + database must be located in that region. + The region cannot be specified when the API endoint is used as `id`. + Note that if this parameter is not passed, and cannot be inferred + from the API endpoint, an additional DevOps API request is made + to determine the default region and use it subsequently. astra_db_admin: an AstraDBAdmin object that has visibility over the target database. + max_time_ms: a timeout, in milliseconds, for the DevOps API + HTTP request should it be necessary (see the `region` argument). Returns: An AstraDBDatabaseAdmin object, for admin work within the database. @@ -1522,11 +1811,13 @@ def from_astra_db_admin( return AstraDBDatabaseAdmin( id=id, token=astra_db_admin.token_provider, + region=region, environment=astra_db_admin.environment, caller_name=astra_db_admin._caller_name, caller_version=astra_db_admin._caller_version, dev_ops_url=astra_db_admin._dev_ops_url, dev_ops_api_version=astra_db_admin._dev_ops_api_version, + max_time_ms=max_time_ms, ) @staticmethod @@ -1582,6 +1873,7 @@ def from_api_endpoint( return AstraDBDatabaseAdmin( id=parsed_api_endpoint.database_id, token=token, + region=parsed_api_endpoint.region, environment=parsed_api_endpoint.environment, caller_name=caller_name, caller_version=caller_version, @@ -1609,12 +1901,12 @@ def info(self, *, max_time_ms: Optional[int] = None) -> AdminDatabaseInfo: 'us-east1' """ - logger.info(f"getting info ('{self.id}')") + logger.info(f"getting info ('{self._database_id}')") req_response = self._astra_db_admin.database_info( - id=self.id, + id=self._database_id, max_time_ms=max_time_ms, ) - logger.info(f"finished getting info ('{self.id}')") + logger.info(f"finished getting info ('{self._database_id}')") return req_response # type: ignore[no-any-return] async def async_info( @@ -1640,12 +1932,12 @@ async def async_info( >>> asyncio.run(wait_until_active(admin_for_my_db)) """ - logger.info(f"getting info ('{self.id}'), async") + logger.info(f"getting info ('{self._database_id}'), async") req_response = await self._astra_db_admin.async_database_info( - id=self.id, + id=self._database_id, max_time_ms=max_time_ms, ) - logger.info(f"finished getting info ('{self.id}'), async") + logger.info(f"finished getting info ('{self._database_id}'), async") return req_response # type: ignore[no-any-return] def list_namespaces(self, *, max_time_ms: Optional[int] = None) -> List[str]: @@ -1663,9 +1955,9 @@ def list_namespaces(self, *, max_time_ms: Optional[int] = None) -> List[str]: ['default_keyspace', 'staging_namespace'] """ - logger.info(f"getting namespaces ('{self.id}')") + logger.info(f"getting namespaces ('{self._database_id}')") info = self.info(max_time_ms=max_time_ms) - logger.info(f"finished getting namespaces ('{self.id}')") + logger.info(f"finished getting namespaces ('{self._database_id}')") if info.raw_info is None: raise DevOpsAPIException("Could not get the 
namespace list.") else: @@ -1697,9 +1989,9 @@ async def async_list_namespaces( True """ - logger.info(f"getting namespaces ('{self.id}'), async") + logger.info(f"getting namespaces ('{self._database_id}'), async") info = await self.async_info(max_time_ms=max_time_ms) - logger.info(f"finished getting namespaces ('{self.id}'), async") + logger.info(f"finished getting namespaces ('{self._database_id}'), async") if info.raw_info is None: raise DevOpsAPIException("Could not get the namespace list.") else: @@ -1711,7 +2003,9 @@ def create_namespace( name: str, *, wait_until_active: bool = True, + update_db_namespace: Optional[bool] = None, max_time_ms: Optional[int] = None, + **kwargs: Any, ) -> Dict[str, Any]: """ Create a namespace in this database as requested, @@ -1727,6 +2021,9 @@ def create_namespace( creation request to the DevOps API, and it will be responsibility of the caller to check the database status/namespace availability before working with it. + update_db_namespace: if True, the `Database` or `AsyncDatabase` class + that spawned this DatabaseAdmin, if any, gets updated to work on + the newly-created namespace starting when this method returns. max_time_ms: a timeout, in milliseconds, for the whole requested operation to complete. Note that a timeout is no guarantee that the creation request @@ -1748,20 +2045,20 @@ def create_namespace( timeout_manager = MultiCallTimeoutManager( overall_max_time_ms=max_time_ms, exception_type="devops_api" ) - logger.info(f"creating namespace '{name}' on '{self.id}'") + logger.info(f"creating namespace '{name}' on '{self._database_id}'") cn_response = self._astra_db_admin._astra_db_ops.create_keyspace( - database=self.id, + database=self._database_id, keyspace=name, timeout_info=base_timeout_info(max_time_ms), ) logger.info( - f"devops api returned from creating namespace '{name}' on '{self.id}'" + f"devops api returned from creating namespace '{name}' on '{self._database_id}'" ) if cn_response is not None and name == cn_response.get("name"): if wait_until_active: last_status_seen = STATUS_MAINTENANCE while last_status_seen == STATUS_MAINTENANCE: - logger.info(f"sleeping to poll for status of '{self.id}'") + logger.info(f"sleeping to poll for status of '{self._database_id}'") time.sleep(DATABASE_POLL_NAMESPACE_SLEEP_TIME) last_status_seen = self.info( max_time_ms=timeout_manager.remaining_timeout_ms(), @@ -1773,7 +2070,11 @@ def create_namespace( # is the namespace found? if name not in self.list_namespaces(): raise DevOpsAPIException("Could not create the namespace.") - logger.info(f"finished creating namespace '{name}' on '{self.id}'") + logger.info( + f"finished creating namespace '{name}' on '{self._database_id}'" + ) + if update_db_namespace: + self.spawner_database.use_namespace(name) return {"ok": 1} else: raise DevOpsAPIException( @@ -1787,7 +2088,9 @@ async def async_create_namespace( # type: ignore[override] name: str, *, wait_until_active: bool = True, + update_db_namespace: Optional[bool] = None, max_time_ms: Optional[int] = None, + **kwargs: Any, ) -> Dict[str, Any]: """ Create a namespace in this database as requested, @@ -1804,6 +2107,9 @@ async def async_create_namespace( # type: ignore[override] creation request to the DevOps API, and it will be responsibility of the caller to check the database status/namespace availability before working with it. 
+ update_db_namespace: if True, the `Database` or `AsyncDatabase` class + that spawned this DatabaseAdmin, if any, gets updated to work on + the newly-created namespace starting when this method returns. max_time_ms: a timeout, in milliseconds, for the whole requested operation to complete. Note that a timeout is no guarantee that the creation request @@ -1823,21 +2129,23 @@ async def async_create_namespace( # type: ignore[override] timeout_manager = MultiCallTimeoutManager( overall_max_time_ms=max_time_ms, exception_type="devops_api" ) - logger.info(f"creating namespace '{name}' on '{self.id}', async") + logger.info(f"creating namespace '{name}' on '{self._database_id}', async") cn_response = await self._astra_db_admin._astra_db_ops.async_create_keyspace( - database=self.id, + database=self._database_id, keyspace=name, timeout_info=base_timeout_info(max_time_ms), ) logger.info( f"devops api returned from creating namespace " - f"'{name}' on '{self.id}', async" + f"'{name}' on '{self._database_id}', async" ) if cn_response is not None and name == cn_response.get("name"): if wait_until_active: last_status_seen = STATUS_MAINTENANCE while last_status_seen == STATUS_MAINTENANCE: - logger.info(f"sleeping to poll for status of '{self.id}', async") + logger.info( + f"sleeping to poll for status of '{self._database_id}', async" + ) await asyncio.sleep(DATABASE_POLL_NAMESPACE_SLEEP_TIME) last_db_info = await self.async_info( max_time_ms=timeout_manager.remaining_timeout_ms(), @@ -1850,7 +2158,11 @@ async def async_create_namespace( # type: ignore[override] # is the namespace found? if name not in await self.async_list_namespaces(): raise DevOpsAPIException("Could not create the namespace.") - logger.info(f"finished creating namespace '{name}' on '{self.id}', async") + logger.info( + f"finished creating namespace '{name}' on '{self._database_id}', async" + ) + if update_db_namespace: + self.spawner_database.use_namespace(name) return {"ok": 1} else: raise DevOpsAPIException( @@ -1899,20 +2211,20 @@ def drop_namespace( timeout_manager = MultiCallTimeoutManager( overall_max_time_ms=max_time_ms, exception_type="devops_api" ) - logger.info(f"dropping namespace '{name}' on '{self.id}'") + logger.info(f"dropping namespace '{name}' on '{self._database_id}'") dk_response = self._astra_db_admin._astra_db_ops.delete_keyspace( - database=self.id, + database=self._database_id, keyspace=name, timeout_info=base_timeout_info(max_time_ms), ) logger.info( - f"devops api returned from dropping namespace '{name}' on '{self.id}'" + f"devops api returned from dropping namespace '{name}' on '{self._database_id}'" ) if dk_response == name: if wait_until_active: last_status_seen = STATUS_MAINTENANCE while last_status_seen == STATUS_MAINTENANCE: - logger.info(f"sleeping to poll for status of '{self.id}'") + logger.info(f"sleeping to poll for status of '{self._database_id}'") time.sleep(DATABASE_POLL_NAMESPACE_SLEEP_TIME) last_status_seen = self.info( max_time_ms=timeout_manager.remaining_timeout_ms(), @@ -1924,7 +2236,9 @@ def drop_namespace( # is the namespace found? 
if name in self.list_namespaces(): raise DevOpsAPIException("Could not drop the namespace.") - logger.info(f"finished dropping namespace '{name}' on '{self.id}'") + logger.info( + f"finished dropping namespace '{name}' on '{self._database_id}'" + ) return {"ok": 1} else: raise DevOpsAPIException( @@ -1973,21 +2287,23 @@ async def async_drop_namespace( # type: ignore[override] timeout_manager = MultiCallTimeoutManager( overall_max_time_ms=max_time_ms, exception_type="devops_api" ) - logger.info(f"dropping namespace '{name}' on '{self.id}', async") + logger.info(f"dropping namespace '{name}' on '{self._database_id}', async") dk_response = await self._astra_db_admin._astra_db_ops.async_delete_keyspace( - database=self.id, + database=self._database_id, keyspace=name, timeout_info=base_timeout_info(max_time_ms), ) logger.info( f"devops api returned from dropping namespace " - f"'{name}' on '{self.id}', async" + f"'{name}' on '{self._database_id}', async" ) if dk_response == name: if wait_until_active: last_status_seen = STATUS_MAINTENANCE while last_status_seen == STATUS_MAINTENANCE: - logger.info(f"sleeping to poll for status of '{self.id}', async") + logger.info( + f"sleeping to poll for status of '{self._database_id}', async" + ) await asyncio.sleep(DATABASE_POLL_NAMESPACE_SLEEP_TIME) last_db_info = await self.async_info( max_time_ms=timeout_manager.remaining_timeout_ms(), @@ -2000,7 +2316,9 @@ async def async_drop_namespace( # type: ignore[override] # is the namespace found? if name in await self.async_list_namespaces(): raise DevOpsAPIException("Could not drop the namespace.") - logger.info(f"finished dropping namespace '{name}' on '{self.id}', async") + logger.info( + f"finished dropping namespace '{name}' on '{self._database_id}', async" + ) return {"ok": 1} else: raise DevOpsAPIException( @@ -2050,13 +2368,13 @@ def drop( which avoids using a deceased database any further. """ - logger.info(f"dropping this database ('{self.id}')") + logger.info(f"dropping this database ('{self._database_id}')") return self._astra_db_admin.drop_database( # type: ignore[no-any-return] - id=self.id, + id=self._database_id, wait_until_active=wait_until_active, max_time_ms=max_time_ms, ) - logger.info(f"finished dropping this database ('{self.id}')") + logger.info(f"finished dropping this database ('{self._database_id}')") async def async_drop( self, @@ -2099,13 +2417,13 @@ async def async_drop( which avoids using a deceased database any further. """ - logger.info(f"dropping this database ('{self.id}'), async") + logger.info(f"dropping this database ('{self._database_id}'), async") return await self._astra_db_admin.async_drop_database( # type: ignore[no-any-return] - id=self.id, + id=self._database_id, wait_until_active=wait_until_active, max_time_ms=max_time_ms, ) - logger.info(f"finished dropping this database ('{self.id}'), async") + logger.info(f"finished dropping this database ('{self._database_id}'), async") def get_database( self, @@ -2118,7 +2436,7 @@ def get_database( max_time_ms: Optional[int] = None, ) -> Database: """ - Create a Database instance out of this class for working with the data in it. + Create a Database instance from this database admin, for data-related tasks. Args: token: if supplied, is passed to the Database instead of @@ -2128,8 +2446,8 @@ def get_database( `astrapy.authentication.TokenProvider`. namespace: an optional namespace to set in the resulting Database. The same default logic as for `AstraDBAdmin.get_database` applies. 
- region: an optional region for connecting to the database Data API endpoint. - The same default logic as for `AstraDBAdmin.get_database` applies. + region: *This parameter is deprecated and should not be used.* + Ignored in the method. api_path: path to append to the API Endpoint. In typical usage, this should be left to its default of "/api/json". api_version: version specifier to append to the API path. In typical @@ -2149,11 +2467,22 @@ def get_database( see the AstraDBAdmin class. """ + if region is not None: + the_warning = DeprecatedWarning( + "The 'region' parameter is deprecated in this method and will be ignored.", + deprecated_in="1.3.2", + removed_in="2.0.0", + details="The database class whose method is invoked already has a region set.", + ) + warnings.warn( + the_warning, + stacklevel=2, + ) + return self._astra_db_admin.get_database( - id=self.id, + id=self.api_endpoint, token=token, namespace=namespace, - region=region, api_path=api_path, api_version=api_version, max_time_ms=max_time_ms, @@ -2186,6 +2515,93 @@ def get_async_database( max_time_ms=max_time_ms, ).to_async() + def find_embedding_providers( + self, *, max_time_ms: Optional[int] = None + ) -> FindEmbeddingProvidersResult: + """ + Query the API for the full information on available embedding providers. + + Args: + max_time_ms: a timeout, in milliseconds, for the DevOps API request. + + Returns: + A `FindEmbeddingProvidersResult` object with the complete information + returned by the API about available embedding providers + + Example (output abridged and indented for clarity): + >>> admin_for_my_db.find_embedding_providers() + FindEmbeddingProvidersResult(embedding_providers=..., openai, ...) + >>> admin_for_my_db.find_embedding_providers().embedding_providers + { + 'openai': EmbeddingProvider( + display_name='OpenAI', + models=[ + EmbeddingProviderModel(name='text-embedding-3-small'), + ... + ] + ), + ... + } + """ + + logger.info("getting list of embedding providers") + fe_response = self._api_commander.request( + payload={"findEmbeddingProviders": {}}, + timeout_info=base_timeout_info(max_time_ms), + ) + if "embeddingProviders" not in fe_response.get("status", {}): + raise DataAPIFaultyResponseException( + text="Faulty response from findEmbeddingProviders API command.", + raw_response=fe_response, + ) + else: + logger.info("finished getting list of embedding providers") + return FindEmbeddingProvidersResult.from_dict(fe_response["status"]) + + async def async_find_embedding_providers( + self, *, max_time_ms: Optional[int] = None + ) -> FindEmbeddingProvidersResult: + """ + Query the API for the full information on available embedding providers. + Async version of the method, for use in an asyncio context. + + Args: + max_time_ms: a timeout, in milliseconds, for the DevOps API request. + + Returns: + A `FindEmbeddingProvidersResult` object with the complete information + returned by the API about available embedding providers + + Example (output abridged and indented for clarity): + >>> admin_for_my_db.find_embedding_providers() + FindEmbeddingProvidersResult(embedding_providers=..., openai, ...) + >>> admin_for_my_db.find_embedding_providers().embedding_providers + { + 'openai': EmbeddingProvider( + display_name='OpenAI', + models=[ + EmbeddingProviderModel(name='text-embedding-3-small'), + ... + ] + ), + ... 
+ } + """ + + logger.info("getting list of embedding providers, async") + fe_response = await self._api_commander.async_request( + payload={"findEmbeddingProviders": {}}, + timeout_info=base_timeout_info(max_time_ms), + ) + if "embeddingProviders" not in fe_response.get("status", {}): + raise DataAPIFaultyResponseException( + text="Faulty response from findEmbeddingProviders API command.", + raw_response=fe_response, + ) + else: + logger.info("finished getting list of embedding providers, async") + return FindEmbeddingProvidersResult.from_dict(fe_response["status"]) + class DataAPIDatabaseAdmin(DatabaseAdmin): """ @@ -2209,12 +2625,20 @@ class DataAPIDatabaseAdmin(DatabaseAdmin): environment: a label, whose value is one of Environment.OTHER (default) or other non-Astra environment values in the `Environment` enum. api_path: path to append to the API Endpoint. In typical usage, this - should be left to its default of "". + class is created by a method such as `Database.get_database_admin()`, + which passes the matching value. Defaults to this portion of the path + being absent. api_version: version specifier to append to the API path. In typical - usage, this should be left to its default of "v1". + usage, this class is created by a method such as + `Database.get_database_admin()`, which passes the matching value. + Defaults to this portion of the path being absent. caller_name: name of the application, or framework, on behalf of which the admin API calls are performed. This ends up in the request user-agent. caller_version: version of the caller. + spawner_database: either a Database or an AsyncDatabase instance. This represents + the database class which spawns this admin object, so that, if required, + a namespace creation can retroactively "use" the new namespace in the spawner. + Used to enable the Async/Database.get_admin_database().create_namespace() pattern. 
Example: >>> from astrapy import DataAPIClient @@ -2245,7 +2669,11 @@ def __init__( api_version: Optional[str] = None, caller_name: Optional[str] = None, caller_version: Optional[str] = None, + spawner_database: Optional[Union[Database, AsyncDatabase]] = None, ) -> None: + # lazy import here to avoid circular dependency + from astrapy.database import Database + self.environment = (environment or Environment.OTHER).lower() self.token_provider = coerce_token_provider(token) self.api_endpoint = api_endpoint @@ -2253,19 +2681,36 @@ def __init__( self.caller_name = caller_name self.caller_version = caller_version # - self._api_path = api_path if api_path is not None else "" - self._api_version = api_version if api_version is not None else "" + self.api_path = api_path if api_path is not None else "" + self.api_version = api_version if api_version is not None else "" # self._commander_headers = { DEFAULT_AUTH_HEADER: self.token_provider.get_token(), } + self._api_commander = APICommander( api_endpoint=self.api_endpoint, - path="/".join(comp for comp in [self._api_path, self._api_version] if comp), + path="/".join(comp for comp in [self.api_path, self.api_version] if comp), headers=self._commander_headers, callers=[(self.caller_name, self.caller_version)], ) + if spawner_database is not None: + self.spawner_database = spawner_database + else: + # leaving the namespace to its per-environment default + # (a task for the Database) + self.spawner_database = Database( + api_endpoint=self.api_endpoint, + token=self.token_provider, + namespace=None, + caller_name=self.caller_name, + caller_version=self.caller_version, + environment=self.environment, + api_path=self.api_path, + api_version=self.api_version, + ) + def __repr__(self) -> str: env_desc = f', environment="{self.environment}"' return ( @@ -2298,8 +2743,8 @@ def _copy( api_endpoint=api_endpoint or self.api_endpoint, token=coerce_token_provider(token) or self.token_provider, environment=environment or self.environment, - api_path=api_path or self._api_path, - api_version=api_version or self._api_version, + api_path=api_path or self.api_path, + api_version=api_version or self.api_version, caller_name=caller_name or self.caller_name, caller_version=caller_version or self.caller_version, ) @@ -2369,7 +2814,7 @@ def set_caller( self.caller_version = caller_version self._api_commander = APICommander( api_endpoint=self.api_endpoint, - path="/".join(comp for comp in [self._api_path, self._api_version] if comp), + path="/".join(comp for comp in [self.api_path, self.api_version] if comp), headers=self._commander_headers, callers=[(self.caller_name, self.caller_version)], ) @@ -2407,7 +2852,9 @@ def create_namespace( name: str, *, replication_options: Optional[Dict[str, Any]] = None, + update_db_namespace: Optional[bool] = None, max_time_ms: Optional[int] = None, + **kwargs: Any, ) -> Dict[str, Any]: """ Create a namespace in the database, returning {'ok': 1} if successful. @@ -2420,6 +2867,9 @@ def create_namespace( replication of the namespace (across database nodes). If provided, it must have a structure similar to: `{"class": "SimpleStrategy", "replication_factor": 1}`. + update_db_namespace: if True, the `Database` or `AsyncDatabase` class + that spawned this DatabaseAdmin, if any, gets updated to work on + the newly-created namespace starting when this method returns. max_time_ms: a timeout, in milliseconds, for the whole requested operation to complete. 
Note that a timeout is no guarantee that the creation request @@ -2462,6 +2912,8 @@ def create_namespace( ) else: logger.info("finished creating namespace") + if update_db_namespace: + self.spawner_database.use_namespace(name) return cn_response["status"] # type: ignore[no-any-return] def drop_namespace( @@ -2543,7 +2995,9 @@ async def async_create_namespace( name: str, *, replication_options: Optional[Dict[str, Any]] = None, + update_db_namespace: Optional[bool] = None, max_time_ms: Optional[int] = None, + **kwargs: Any, ) -> Dict[str, Any]: """ Create a namespace in the database, returning {'ok': 1} if successful. @@ -2557,6 +3011,9 @@ async def async_create_namespace( replication of the namespace (across database nodes). If provided, it must have a structure similar to: `{"class": "SimpleStrategy", "replication_factor": 1}`. + update_db_namespace: if True, the `Database` or `AsyncDatabase` class + that spawned this DatabaseAdmin, if any, gets updated to work on + the newly-created namespace starting when this method returns. max_time_ms: a timeout, in milliseconds, for the whole requested operation to complete. Note that a timeout is no guarantee that the creation request @@ -2601,6 +3058,8 @@ async def async_create_namespace( ) else: logger.info("finished creating namespace, async") + if update_db_namespace: + self.spawner_database.use_namespace(name) return cn_response["status"] # type: ignore[no-any-return] async def async_drop_namespace( @@ -2635,7 +3094,7 @@ async def async_drop_namespace( >>> admin_for_my_db.list_namespaces() ['default_keyspace'] """ - logger.info("dropping namespace") + logger.info("dropping namespace, async") dn_response = await self._api_commander.async_request( payload={"dropNamespace": {"name": name}}, timeout_info=base_timeout_info(max_time_ms), @@ -2646,7 +3105,7 @@ async def async_drop_namespace( raw_response=dn_response, ) else: - logger.info("finished dropping namespace") + logger.info("finished dropping namespace, async") return dn_response["status"] # type: ignore[no-any-return] def get_database( @@ -2667,7 +3126,8 @@ def get_database( This can be either a literal token string or a subclass of `astrapy.authentication.TokenProvider`. namespace: an optional namespace to set in the resulting Database. - If not provided, the default namespace is used. + If not provided, no namespace is set, limiting what the Database + can do until setting it with e.g. a `useNamespace` method call. api_path: path to append to the API Endpoint. In typical usage, this should be left to its default of "". api_version: version specifier to append to the API path. In typical @@ -2721,3 +3181,90 @@ def get_async_database( api_path=api_path, api_version=api_version, ).to_async() + + def find_embedding_providers( + self, *, max_time_ms: Optional[int] = None + ) -> FindEmbeddingProvidersResult: + """ + Query the API for the full information on available embedding providers. + + Args: + max_time_ms: a timeout, in milliseconds, for the DevOps API request. + + Returns: + A `FindEmbeddingProvidersResult` object with the complete information + returned by the API about available embedding providers + + Example (output abridged and indented for clarity): + >>> admin_for_my_db.find_embedding_providers() + FindEmbeddingProvidersResult(embedding_providers=..., openai, ...) + >>> admin_for_my_db.find_embedding_providers().embedding_providers + { + 'openai': EmbeddingProvider( + display_name='OpenAI', + models=[ + EmbeddingProviderModel(name='text-embedding-3-small'), + ... + ] + ), + ... 
+ } + """ + + logger.info("getting list of embedding providers") + fe_response = self._api_commander.request( + payload={"findEmbeddingProviders": {}}, + timeout_info=base_timeout_info(max_time_ms), + ) + if "embeddingProviders" not in fe_response.get("status", {}): + raise DataAPIFaultyResponseException( + text="Faulty response from findEmbeddingProviders API command.", + raw_response=fe_response, + ) + else: + logger.info("finished getting list of embedding providers") + return FindEmbeddingProvidersResult.from_dict(fe_response["status"]) + + async def async_find_embedding_providers( + self, *, max_time_ms: Optional[int] = None + ) -> FindEmbeddingProvidersResult: + """ + Query the API for the full information on available embedding providers. + Async version of the method, for use in an asyncio context. + + Args: + max_time_ms: a timeout, in milliseconds, for the DevOps API request. + + Returns: + A `FindEmbeddingProvidersResult` object with the complete information + returned by the API about available embedding providers + + Example (output abridged and indented for clarity): + >>> admin_for_my_db.find_embedding_providers() + FindEmbeddingProvidersResult(embedding_providers=..., openai, ...) + >>> admin_for_my_db.find_embedding_providers().embedding_providers + { + 'openai': EmbeddingProvider( + display_name='OpenAI', + models=[ + EmbeddingProviderModel(name='text-embedding-3-small'), + ... + ] + ), + ... + } + """ + + logger.info("getting list of embedding providers, async") + fe_response = await self._api_commander.async_request( + payload={"findEmbeddingProviders": {}}, + timeout_info=base_timeout_info(max_time_ms), + ) + if "embeddingProviders" not in fe_response.get("status", {}): + raise DataAPIFaultyResponseException( + text="Faulty response from findEmbeddingProviders API command.", + raw_response=fe_response, + ) + else: + logger.info("finished getting list of embedding providers, async") + return FindEmbeddingProvidersResult.from_dict(fe_response["status"]) diff --git a/astrapy/api/__init__.py b/astrapy/api/__init__.py index 4d22daa5..066cfdaa 100644 --- a/astrapy/api/__init__.py +++ b/astrapy/api/__init__.py @@ -14,6 +14,8 @@ """Core "api" subpackage, exported here to preserve import patterns.""" +from __future__ import annotations + from astrapy.core.api import APIRequestError __all__ = [ diff --git a/astrapy/api_options.py b/astrapy/api_options.py index 88f60f04..5b84179b 100644 --- a/astrapy/api_options.py +++ b/astrapy/api_options.py @@ -18,8 +18,8 @@ from typing import Optional, TypeVar from astrapy.authentication import ( + EmbeddingAPIKeyHeaderProvider, EmbeddingHeadersProvider, - StaticEmbeddingHeadersProvider, ) AO = TypeVar("AO", bound="BaseAPIOptions") @@ -117,12 +117,12 @@ class CollectionAPIOptions(BaseAPIOptions): embedding_api_key: an `astrapy.authentication.EmbeddingHeadersProvider` object, encoding embedding-related API keys that will be passed as headers when interacting with the collection (on each Data API request). - The default value is `StaticEmbeddingHeadersProvider(None)`, i.e. + The default value is `EmbeddingAPIKeyHeaderProvider(None)`, i.e. no embedding-specific headers, whereas if the collection is configured with an embedding service other choices for this parameter can be meaningfully supplied. 
is configured for the collection, """ embedding_api_key: EmbeddingHeadersProvider = field( - default_factory=lambda: StaticEmbeddingHeadersProvider(None) + default_factory=lambda: EmbeddingAPIKeyHeaderProvider(None) ) diff --git a/astrapy/authentication.py b/astrapy/authentication.py index 49174368..5e7b7d12 100644 --- a/astrapy/authentication.py +++ b/astrapy/authentication.py @@ -36,7 +36,7 @@ def coerce_embedding_headers_provider( if isinstance(embedding_api_key, EmbeddingHeadersProvider): return embedding_api_key else: - return StaticEmbeddingHeadersProvider(embedding_api_key) + return EmbeddingAPIKeyHeaderProvider(embedding_api_key) class TokenProvider(ABC): @@ -202,7 +202,7 @@ def get_headers(self) -> Dict[str, str]: ... -class StaticEmbeddingHeadersProvider(EmbeddingHeadersProvider): +class EmbeddingAPIKeyHeaderProvider(EmbeddingHeadersProvider): """ A "pass-through" header provider representing the single-header (typically "X-Embedding-Api-Key") auth scheme, in use by most of the @@ -217,9 +217,9 @@ class StaticEmbeddingHeadersProvider(EmbeddingHeadersProvider): >>> from astrapy import DataAPIClient >>> from astrapy.authentication import ( CollectionVectorServiceOptions, - StaticEmbeddingHeadersProvider, + EmbeddingAPIKeyHeaderProvider, ) - >>> my_emb_api_key = StaticEmbeddingHeadersProvider("abc012...") + >>> my_emb_api_key = EmbeddingAPIKeyHeaderProvider("abc012...") >>> service_options = CollectionVectorServiceOptions( ... provider="a-certain-provider", ... model_name="some-embedding-model", @@ -305,7 +305,11 @@ def __init__(self, *, embedding_access_id: str, embedding_secret_id: str) -> Non self.embedding_secret_id = embedding_secret_id def __repr__(self) -> str: - return f'{self.__class__.__name__}("{self.embedding_access_id[:3]}...", "{self.embedding_secret_id[:3]}...")' + return ( + f"{self.__class__.__name__}(embedding_access_id=" + f'"{self.embedding_access_id[:3]}...", ' + f'embedding_secret_id="{self.embedding_secret_id[:3]}...")' + ) def get_headers(self) -> Dict[str, str]: return { diff --git a/astrapy/client.py b/astrapy/client.py index 5002f00b..affd9ffe 100644 --- a/astrapy/client.py +++ b/astrapy/client.py @@ -16,13 +16,14 @@ import logging import re -from typing import TYPE_CHECKING, Any, Dict, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union from astrapy.admin import ( api_endpoint_parser, build_api_endpoint, database_id_matcher, fetch_raw_database_info_from_id_token, + normalize_id_endpoint_parameters, parse_api_endpoint, parse_generic_api_url, ) @@ -205,8 +206,9 @@ def set_caller( def get_database( self, - id: str, + id: Optional[str] = None, *, + api_endpoint: Optional[str] = None, token: Optional[Union[str, TokenProvider]] = None, namespace: Optional[str] = None, region: Optional[str] = None, @@ -223,14 +225,16 @@ def get_database( to be effectively used; in other words, this invocation does not create the database, just the object instance. Actual admin work can be achieved by using the AstraDBAdmin object. + api_endpoint: a named alias for the `id` first (positional) parameter, + with the same meaning. It cannot be passed together with `id`. token: if supplied, is passed to the Database instead of the client token. This can be either a literal token string or a subclass of `astrapy.authentication.TokenProvider`. - namespace: if provided, is passed to the Database - (it is left to the default otherwise). + namespace: if provided, it is passed to the Database; otherwise + the Database class will apply an environment-specific default. 
region: the region to use for connecting to the database. The database must be located in that region. - The region cannot be specified when he API endoint is used as `id`. + The region cannot be specified when the API endoint is used as `id`. Note that if this parameter is not passed, and cannot be inferred from the API endpoint, an additional DevOps API request is made to determine the default region and use it subsequently. @@ -263,43 +267,44 @@ def get_database( # lazy importing here to avoid circular dependency from astrapy import Database + # id/endpoint parameter normalization + _id_or_endpoint = normalize_id_endpoint_parameters(id, api_endpoint) if self.environment in Environment.astra_db_values: # handle the "endpoint passed as id" case first: - if re.match(api_endpoint_parser, id): + if re.match(api_endpoint_parser, _id_or_endpoint): if region is not None: raise ValueError( "Parameter `region` not supported when supplying an API endpoint." ) # in this case max_time_ms is ignored (no calls take place) return self.get_database_by_api_endpoint( - api_endpoint=id, + api_endpoint=_id_or_endpoint, token=token, namespace=namespace, api_path=api_path, api_version=api_version, ) else: - # need to inspect for values? - this_db_info: Optional[Dict[str, Any]] = None # handle overrides. Only region is needed (namespace can stay empty) if region: _region = region else: - if this_db_info is None: - logger.info(f"fetching raw database info for {id}") - this_db_info = fetch_raw_database_info_from_id_token( - id=id, - token=self.token_provider.get_token(), - environment=self.environment, - max_time_ms=max_time_ms, - ) - logger.info(f"finished fetching raw database info for {id}") + logger.info(f"fetching raw database info for {_id_or_endpoint}") + this_db_info = fetch_raw_database_info_from_id_token( + id=_id_or_endpoint, + token=self.token_provider.get_token(), + environment=self.environment, + max_time_ms=max_time_ms, + ) + logger.info( + f"finished fetching raw database info for {_id_or_endpoint}" + ) _region = this_db_info["info"]["region"] _token = coerce_token_provider(token) or self.token_provider _api_endpoint = build_api_endpoint( environment=self.environment, - database_id=id, + database_id=_id_or_endpoint, region=_region, ) return Database( @@ -315,13 +320,13 @@ def get_database( else: # in this case, this call is an alias for get_database_by_api_endpoint # - max_time_ms ignored - # - assume `id` is actually the endpoint + # - assume `_id_or_endpoint` is actually the endpoint if region is not None: raise ValueError( "Parameter `region` not supported outside of Astra DB." ) return self.get_database_by_api_endpoint( - api_endpoint=id, + api_endpoint=_id_or_endpoint, token=token, namespace=namespace, api_path=api_path, @@ -330,8 +335,9 @@ def get_database( def get_async_database( self, - id: str, + id: Optional[str] = None, *, + api_endpoint: Optional[str] = None, token: Optional[Union[str, TokenProvider]] = None, namespace: Optional[str] = None, region: Optional[str] = None, @@ -348,6 +354,7 @@ def get_async_database( return self.get_database( id=id, + api_endpoint=api_endpoint, token=token, namespace=namespace, region=region, @@ -378,8 +385,8 @@ def get_database_by_api_endpoint( token: if supplied, is passed to the Database instead of the client token. This can be either a literal token string or a subclass of `astrapy.authentication.TokenProvider`. - namespace: if provided, is passed to the Database - (it is left to the default otherwise). 
+ namespace: if provided, it is passed to the Database; otherwise + the Database class will apply an environment-specific default. api_path: path to append to the API Endpoint. In typical usage, this should be left to its default of "/api/json". api_version: version specifier to append to the API path. In typical diff --git a/astrapy/collection.py b/astrapy/collection.py index a72ebca9..70c6de94 100644 --- a/astrapy/collection.py +++ b/astrapy/collection.py @@ -343,7 +343,7 @@ def with_options( each Data API call will include the necessary embedding-related headers as specified by this parameter. If a string is passed, it translates into the one "embedding api key" header - (i.e. `astrapy.authentication.StaticEmbeddingHeadersProvider`). + (i.e. `astrapy.authentication.EmbeddingAPIKeyHeaderProvider`). For some vectorize providers/models, if using header-based authentication, specialized subclasses of `astrapy.authentication.EmbeddingHeadersProvider` should be supplied. @@ -410,7 +410,7 @@ def to_async( each Data API call will include the necessary embedding-related headers as specified by this parameter. If a string is passed, it translates into the one "embedding api key" header - (i.e. `astrapy.authentication.StaticEmbeddingHeadersProvider`). + (i.e. `astrapy.authentication.EmbeddingAPIKeyHeaderProvider`). For some vectorize providers/models, if using header-based authentication, specialized subclasses of `astrapy.authentication.EmbeddingHeadersProvider` should be supplied. @@ -561,7 +561,10 @@ def namespace(self) -> str: 'default_keyspace' """ - return self.database.namespace + _namespace = self.database.namespace + if _namespace is None: + raise ValueError("The collection's DB is set with namespace=None") + return _namespace @property def name(self) -> str: @@ -2761,7 +2764,7 @@ def with_options( each Data API call will include the necessary embedding-related headers as specified by this parameter. If a string is passed, it translates into the one "embedding api key" header - (i.e. `astrapy.authentication.StaticEmbeddingHeadersProvider`). + (i.e. `astrapy.authentication.EmbeddingAPIKeyHeaderProvider`). For some vectorize providers/models, if using header-based authentication, specialized subclasses of `astrapy.authentication.EmbeddingHeadersProvider` should be supplied. @@ -2828,7 +2831,7 @@ def to_sync( each Data API call will include the necessary embedding-related headers as specified by this parameter. If a string is passed, it translates into the one "embedding api key" header - (i.e. `astrapy.authentication.StaticEmbeddingHeadersProvider`). + (i.e. `astrapy.authentication.EmbeddingAPIKeyHeaderProvider`). For some vectorize providers/models, if using header-based authentication, specialized subclasses of `astrapy.authentication.EmbeddingHeadersProvider` should be supplied. @@ -2981,7 +2984,10 @@ def namespace(self) -> str: 'default_keyspace' """ - return self.database.namespace + _namespace = self.database.namespace + if _namespace is None: + raise ValueError("The collection's DB is set with namespace=None") + return _namespace @property def name(self) -> str: diff --git a/astrapy/core/__init__.py b/astrapy/core/__init__.py index 2c9ca172..84497ed1 100644 --- a/astrapy/core/__init__.py +++ b/astrapy/core/__init__.py @@ -11,3 +11,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
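# --- Illustrative note (not part of the diff): Collection.namespace now validates ---
# If a collection is bound to a database whose working namespace is unset, reading
# the property raises instead of returning None (`my_collection` is assumed here):
try:
    print(my_collection.namespace)
except ValueError as exc:
    print(exc)  # The collection's DB is set with namespace=None
# --- end of note ---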
+ +from __future__ import annotations diff --git a/astrapy/core/db.py b/astrapy/core/db.py index 10339f84..21b4e8a2 100644 --- a/astrapy/core/db.py +++ b/astrapy/core/db.py @@ -2384,13 +2384,20 @@ async def concurrent_insert_many( async with sem: logger.debug(f"Processing chunk #{index + 1} of size {len(docs)}") try: - return await self.insert_many( + im_result = await self.insert_many( documents=docs, options=options, partial_failures_allowed=partial_failures_allowed, timeout_info=timeout_info, ) + logger.debug( + f"Finished processing chunk #{index + 1} of size {len(docs)}" + ) + return im_result except APIRequestError as e: + logger.debug( + f"Got APIRequestError while processing chunk #{index + 1} of size {len(docs)}" + ) if partial_failures_allowed: return e else: @@ -2400,7 +2407,9 @@ async def concurrent_insert_many( tasks = [ asyncio.create_task( concurrent_insert_many( - documents[i : i + chunk_size], i, partial_failures_allowed + documents[i : i + chunk_size], + i // chunk_size, + partial_failures_allowed, ) ) for i in range(0, len(documents), chunk_size) @@ -2411,7 +2420,9 @@ async def concurrent_insert_many( # "sequential strictly obeys fail-fast if ordered and concurrency==1" results = [ await concurrent_insert_many( - documents[i : i + chunk_size], i, partial_failures_allowed + documents[i : i + chunk_size], + i // chunk_size, + partial_failures_allowed, ) for i in range(0, len(documents), chunk_size) ] diff --git a/astrapy/database.py b/astrapy/database.py index 3e572844..cf404a2c 100644 --- a/astrapy/database.py +++ b/astrapy/database.py @@ -53,6 +53,8 @@ from astrapy.collection import AsyncCollection, Collection +DEFAULT_ASTRA_DB_NAMESPACE = "default_keyspace" + logger = logging.getLogger(__name__) @@ -118,7 +120,10 @@ class Database: `astrapy.authentication.TokenProvider`. namespace: this is the namespace all method calls will target, unless one is explicitly specified in the call. If no namespace is supplied - when creating a Database, the name "default_namespace" is set. + when creating a Database, on Astra DB the name "default_namespace" is set, + while on other environments the namespace is left unspecified: in this case, + most operations are unavailable until a namespace is set (through an explicit + `use_namespace` invocation or equivalent). caller_name: name of the application, or framework, on behalf of which the Data API calls are performed. This ends up in the request user-agent. caller_version: version of the caller. 
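# --- Illustrative sketch (not part of the diff): namespace defaulting per environment ---
# The constant added above makes the Astra fallback "default_keyspace" (the class
# docstring wording says "default_namespace"; the constant value is what the code
# applies). Endpoint/token values are placeholders and the `Environment` import
# path is an assumption.
from astrapy import Database
from astrapy.constants import Environment

astra_db = Database(
    api_endpoint="https://<db_id>-<region>.apps.astra.datastax.com",
    token="AstraCS:...",
)
assert astra_db.namespace == "default_keyspace"

hcd_db = Database(
    api_endpoint="http://localhost:8181",
    token="a-token-for-the-environment",  # placeholder
    environment=Environment.HCD,
)
assert hcd_db.namespace is None  # most operations error out until a namespace is set
# --- end of sketch ---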
@@ -160,23 +165,28 @@ def __init__( _api_path: Optional[str] _api_version: Optional[str] if api_path is None: - _api_path = API_PATH_ENV_MAP.get(self.environment) + _api_path = API_PATH_ENV_MAP[self.environment] else: _api_path = api_path if api_version is None: - _api_version = API_VERSION_ENV_MAP.get(self.environment) + _api_version = API_VERSION_ENV_MAP[self.environment] else: _api_version = api_version self.token_provider = coerce_token_provider(token) - self._astra_db = AstraDB( - token=self.token_provider.get_token(), - api_endpoint=api_endpoint, - api_path=_api_path, - api_version=_api_version, - namespace=namespace, - caller_name=caller_name, - caller_version=caller_version, - ) + self.api_endpoint = api_endpoint.strip("/") + self.api_path = _api_path + self.api_version = _api_version + + # enforce defaults if on Astra DB: + self.using_namespace: Optional[str] + if namespace is None and self.environment in Environment.astra_db_values: + self.using_namespace = DEFAULT_ASTRA_DB_NAMESPACE + else: + self.using_namespace = namespace + + self.caller_name = caller_name + self.caller_version = caller_version + self._astra_db = self._refresh_astra_db() self._name: Optional[str] = None def __getattr__(self, collection_name: str) -> Collection: @@ -186,17 +196,41 @@ def __getitem__(self, collection_name: str) -> Collection: return self.get_collection(name=collection_name) def __repr__(self) -> str: + namespace_desc = self.namespace if self.namespace is not None else "(not set)" return ( - f'{self.__class__.__name__}(api_endpoint="{self._astra_db.api_endpoint}", ' - f'token="{str(self.token_provider)[:12]}...", namespace="{self._astra_db.namespace}")' + f'{self.__class__.__name__}(api_endpoint="{self.api_endpoint}", ' + f'token="{str(self.token_provider)[:12]}...", namespace="{namespace_desc}")' ) def __eq__(self, other: Any) -> bool: if isinstance(other, Database): - return self._astra_db == other._astra_db + return all( + [ + self.token_provider == other.token_provider, + self.api_endpoint == other.api_endpoint, + self.api_path == other.api_path, + self.api_version == other.api_version, + self.namespace == other.namespace, + self.caller_name == other.caller_name, + self.caller_version == other.caller_version, + ] + ) else: return False + def _refresh_astra_db(self) -> AstraDB: + """Re-instantiate a new (core) client based on the instance attributes.""" + logger.info("Instantiating a new (core) AstraDB") + return AstraDB( + token=self.token_provider.get_token(), + api_endpoint=self.api_endpoint, + api_path=self.api_path, + api_version=self.api_version, + namespace=self.namespace, + caller_name=self.caller_name, + caller_version=self.caller_version, + ) + def _copy( self, *, @@ -210,14 +244,14 @@ def _copy( api_version: Optional[str] = None, ) -> Database: return Database( - api_endpoint=api_endpoint or self._astra_db.api_endpoint, + api_endpoint=api_endpoint or self.api_endpoint, token=coerce_token_provider(token) or self.token_provider, - namespace=namespace or self._astra_db.namespace, - caller_name=caller_name or self._astra_db.caller_name, - caller_version=caller_version or self._astra_db.caller_version, + namespace=namespace or self.namespace, + caller_name=caller_name or self.caller_name, + caller_version=caller_version or self.caller_version, environment=environment or self.environment, - api_path=api_path or self._astra_db.api_path, - api_version=api_version or self._astra_db.api_version, + api_path=api_path or self.api_path, + api_version=api_version or self.api_version, ) def 
with_options( @@ -301,14 +335,14 @@ def to_async( """ return AsyncDatabase( - api_endpoint=api_endpoint or self._astra_db.api_endpoint, + api_endpoint=api_endpoint or self.api_endpoint, token=coerce_token_provider(token) or self.token_provider, - namespace=namespace or self._astra_db.namespace, - caller_name=caller_name or self._astra_db.caller_name, - caller_version=caller_version or self._astra_db.caller_version, + namespace=namespace or self.namespace, + caller_name=caller_name or self.caller_name, + caller_version=caller_version or self.caller_version, environment=environment or self.environment, - api_path=api_path or self._astra_db.api_path, - api_version=api_version or self._astra_db.api_version, + api_path=api_path or self.api_path, + api_version=api_version or self.api_version, ) def set_caller( @@ -330,10 +364,34 @@ def set_caller( """ logger.info(f"setting caller to {caller_name}/{caller_version}") - self._astra_db.set_caller( - caller_name=caller_name, - caller_version=caller_version, - ) + self.caller_name = caller_name + self.caller_version = caller_version + self._astra_db = self._refresh_astra_db() + + def use_namespace(self, namespace: str) -> None: + """ + Switch to a new working namespace for this database. + This method changes (mutates) the Database instance. + + Note that this method does not create the namespace, which should exist + already (created for instance with a `DatabaseAdmin.create_namespace` call). + + Args: + namespace: the new namespace to use as the database working namespace. + + Returns: + None. + + Example: + >>> my_db.list_collection_names() + ['coll_1', 'coll_2'] + >>> my_db.use_namespace("an_empty_namespace") + >>> my_db.list_collection_names() + [] + """ + logger.info(f"switching to namespace '{namespace}'") + self.using_namespace = namespace + self._astra_db = self._refresh_astra_db() def info(self) -> DatabaseInfo: """ @@ -357,7 +415,7 @@ def info(self) -> DatabaseInfo: logger.info("getting database info") database_info = fetch_database_info( - self._astra_db.api_endpoint, + self.api_endpoint, token=self.token_provider.get_token(), namespace=self.namespace, ) @@ -379,7 +437,7 @@ def id(self) -> str: '01234567-89ab-cdef-0123-456789abcdef' """ - parsed_api_endpoint = parse_api_endpoint(self._astra_db.api_endpoint) + parsed_api_endpoint = parse_api_endpoint(self.api_endpoint) if parsed_api_endpoint is not None: return parsed_api_endpoint.database_id else: @@ -405,17 +463,20 @@ def name(self) -> str: return self._name @property - def namespace(self) -> str: + def namespace(self) -> Optional[str]: """ The namespace this database uses as target for all commands when no method-call-specific namespace is specified. + Returns: + the working namespace (a string), or None if not set. + Example: >>> my_db.namespace 'the_keyspace' """ - return self._astra_db.namespace + return self.using_namespace def get_collection( self, @@ -444,7 +505,7 @@ def get_collection( each Data API call will include the necessary embedding-related headers as specified by this parameter. If a string is passed, it translates into the one "embedding api key" header - (i.e. `astrapy.authentication.StaticEmbeddingHeadersProvider`). + (i.e. `astrapy.authentication.EmbeddingAPIKeyHeaderProvider`). For some vectorize providers/models, if using header-based authentication, specialized subclasses of `astrapy.authentication.EmbeddingHeadersProvider` should be supplied. 
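# --- Illustrative sketch (not part of the diff): attribute-based equality and use_namespace ---
# Database.__eq__ now compares token, endpoint, paths, namespace and caller info,
# so a copy starts out equal and diverges once the working namespace is switched.
# `my_db` is assumed; the target namespace must already exist and differ from the current one.
db_copy = my_db._copy()
assert db_copy == my_db
db_copy.use_namespace("another_existing_namespace")
assert db_copy != my_db
assert db_copy.namespace == "another_existing_namespace"
# --- end of sketch ---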
@@ -477,7 +538,12 @@ def get_collection( # lazy importing here against circular-import error from astrapy.collection import Collection - _namespace = namespace or self._astra_db.namespace + _namespace = namespace or self.namespace + if _namespace is None: + raise ValueError( + "No namespace specified. This operation requires a namespace to " + "be set, e.g. through the `use_namespace` method." + ) return Collection( self, name, @@ -552,7 +618,7 @@ def create_collection( each Data API call will include the necessary embedding-related headers as specified by this parameter. If a string is passed, it translates into the one "embedding api key" header - (i.e. `astrapy.authentication.StaticEmbeddingHeadersProvider`). + (i.e. `astrapy.authentication.EmbeddingAPIKeyHeaderProvider`). For some vectorize providers/models, if using header-based authentication, specialized subclasses of `astrapy.authentication.EmbeddingHeadersProvider` should be supplied. @@ -612,11 +678,18 @@ def create_collection( else: existing_names = [] - driver_db = self._astra_db.copy(namespace=namespace) + _namespace = namespace or self.namespace + if _namespace is None: + raise ValueError( + "No namespace specified. This operation requires a namespace to " + "be set, e.g. through the `use_namespace` method." + ) + + driver_db = self._astra_db.copy(namespace=_namespace) if name in existing_names: raise CollectionAlreadyExistsException( text=f"CollectionInvalid: collection {name} already exists", - namespace=driver_db.namespace, + namespace=_namespace, collection_name=name, ) @@ -688,6 +761,11 @@ def drop_collection( logger.info(f"finished dropping collection '{_name}'") return dc_response.get("status", {}) # type: ignore[no-any-return] else: + if self.namespace is None: + raise ValueError( + "No namespace specified. This operation requires a namespace to " + "be set, e.g. through the `use_namespace` method." + ) logger.info(f"dropping collection '{name_or_collection}'") dc_response = self._astra_db.delete_collection( name_or_collection, @@ -727,12 +805,16 @@ def list_collections( CollectionDescriptor(name='my_v_col', options=CollectionOptions()) """ - if namespace: - _client = self._astra_db.copy(namespace=namespace) - else: - _client = self._astra_db + _namespace = namespace or self.namespace + if _namespace is None: + raise ValueError( + "No namespace specified. This operation requires a namespace to " + "be set, e.g. through the `use_namespace` method." + ) + + driver_db = self._astra_db.copy(namespace=_namespace) logger.info("getting collections") - gc_response = _client.get_collections( + gc_response = driver_db.get_collections( options={"explain": True}, timeout_info=base_timeout_info(max_time_ms) ) if "collections" not in gc_response.get("status", {}): @@ -744,7 +826,7 @@ def list_collections( # we know this is a list of dicts, to marshal into "descriptors" logger.info("finished getting collections") return CommandCursor( - address=self._astra_db.base_url, + address=driver_db.base_url, items=[ CollectionDescriptor.from_dict(col_dict) for col_dict in gc_response["status"]["collections"] @@ -774,8 +856,15 @@ def list_collection_names( ['a_collection', 'another_col'] """ + _namespace = namespace or self.namespace + if _namespace is None: + raise ValueError( + "No namespace specified. This operation requires a namespace to " + "be set, e.g. through the `use_namespace` method." 
+ ) + logger.info("getting collection names") - gc_response = self._astra_db.copy(namespace=namespace).get_collections( + gc_response = self._astra_db.copy(namespace=_namespace).get_collections( timeout_info=base_timeout_info(max_time_ms) ) if "collections" not in gc_response.get("status", {}): @@ -820,12 +909,16 @@ def command( {'status': {'count': 123}} """ - if namespace: - _client = self._astra_db.copy(namespace=namespace) - else: - _client = self._astra_db + _namespace = namespace or self.namespace + if _namespace is None: + raise ValueError( + "No namespace specified. This operation requires a namespace to " + "be set, e.g. through the `use_namespace` method." + ) + + driver_db = self._astra_db.copy(namespace=_namespace) if collection_name: - _collection = _client.collection(collection_name) + _collection = driver_db.collection(collection_name) logger.info(f"issuing custom command to API (on '{collection_name}')") req_response = _collection.post_raw_request( body=body, @@ -837,7 +930,7 @@ def command( return req_response else: logger.info("issuing custom command to API") - req_response = _client.post_raw_request( + req_response = driver_db.post_raw_request( body=body, timeout_info=base_timeout_info(max_time_ms), ) @@ -890,13 +983,15 @@ def get_database_admin( from astrapy.admin import AstraDBDatabaseAdmin, DataAPIDatabaseAdmin if self.environment in Environment.astra_db_values: - return AstraDBDatabaseAdmin.from_api_endpoint( - api_endpoint=self._astra_db.api_endpoint, + return AstraDBDatabaseAdmin( + api_endpoint=self.api_endpoint, token=coerce_token_provider(token) or self.token_provider, - caller_name=self._astra_db.caller_name, - caller_version=self._astra_db.caller_version, + environment=self.environment, + caller_name=self.caller_name, + caller_version=self.caller_version, dev_ops_url=dev_ops_url, dev_ops_api_version=dev_ops_api_version, + spawner_database=self, ) else: if dev_ops_url is not None: @@ -908,13 +1003,14 @@ def get_database_admin( "Parameter `dev_ops_api_version` not supported outside of Astra DB." ) return DataAPIDatabaseAdmin( - api_endpoint=self._astra_db.api_endpoint, + api_endpoint=self.api_endpoint, token=coerce_token_provider(token) or self.token_provider, environment=self.environment, - api_path=self._astra_db.api_path, - api_version=self._astra_db.api_version, - caller_name=self._astra_db.caller_name, - caller_version=self._astra_db.caller_version, + api_path=self.api_path, + api_version=self.api_version, + caller_name=self.caller_name, + caller_version=self.caller_version, + spawner_database=self, ) @@ -939,7 +1035,10 @@ class AsyncDatabase: `astrapy.authentication.TokenProvider`. namespace: this is the namespace all method calls will target, unless one is explicitly specified in the call. If no namespace is supplied - when creating a Database, the name "default_namespace" is set. + when creating a Database, on Astra DB the name "default_namespace" is set, + while on other environments the namespace is left unspecified: in this case, + most operations are unavailable until a namespace is set (through an explicit + `use_namespace` invocation or equivalent). caller_name: name of the application, or framework, on behalf of which the Data API calls are performed. This ends up in the request user-agent. caller_version: version of the caller. 
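# --- Illustrative sketch (not part of the diff): admin spawned from a Database ---
# get_database_admin() above now passes `spawner_database=self`, which enables the
# `update_db_namespace` flag exercised in the test suite further below: creating a
# namespace can also retarget the spawning database (`my_db` is assumed).
database_admin = my_db.get_database_admin()
database_admin.create_namespace("the_namespace", update_db_namespace=True)
assert my_db.namespace == "the_namespace"
# --- end of sketch ---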
@@ -981,24 +1080,29 @@ def __init__( _api_path: Optional[str] _api_version: Optional[str] if api_path is None: - _api_path = API_PATH_ENV_MAP.get(self.environment) + _api_path = API_PATH_ENV_MAP[self.environment] else: _api_path = api_path if api_version is None: - _api_version = API_VERSION_ENV_MAP.get(self.environment) + _api_version = API_VERSION_ENV_MAP[self.environment] else: _api_version = api_version # self.token_provider = coerce_token_provider(token) - self._astra_db = AsyncAstraDB( - token=self.token_provider.get_token(), - api_endpoint=api_endpoint, - api_path=_api_path, - api_version=_api_version, - namespace=namespace, - caller_name=caller_name, - caller_version=caller_version, - ) + self.api_endpoint = api_endpoint.strip("/") + self.api_path = _api_path + self.api_version = _api_version + + # enforce defaults if on Astra DB: + self.using_namespace: Optional[str] + if namespace is None and self.environment in Environment.astra_db_values: + self.using_namespace = DEFAULT_ASTRA_DB_NAMESPACE + else: + self.using_namespace = namespace + + self.caller_name = caller_name + self.caller_version = caller_version + self._astra_db = self._refresh_astra_db() self._name: Optional[str] = None def __getattr__(self, collection_name: str) -> AsyncCollection: @@ -1008,14 +1112,25 @@ def __getitem__(self, collection_name: str) -> AsyncCollection: return self.to_sync().get_collection(name=collection_name).to_async() def __repr__(self) -> str: + namespace_desc = self.namespace if self.namespace is not None else "(not set)" return ( - f'{self.__class__.__name__}(api_endpoint="{self._astra_db.api_endpoint}", ' - f'token="{str(self.token_provider)[:12]}...", namespace="{self._astra_db.namespace}")' + f'{self.__class__.__name__}(api_endpoint="{self.api_endpoint}", ' + f'token="{str(self.token_provider)[:12]}...", namespace="{namespace_desc}")' ) def __eq__(self, other: Any) -> bool: if isinstance(other, AsyncDatabase): - return self._astra_db == other._astra_db + return all( + [ + self.token_provider == other.token_provider, + self.api_endpoint == other.api_endpoint, + self.api_path == other.api_path, + self.api_version == other.api_version, + self.namespace == other.namespace, + self.caller_name == other.caller_name, + self.caller_version == other.caller_version, + ] + ) else: return False @@ -1034,6 +1149,19 @@ async def __aexit__( traceback=traceback, ) + def _refresh_astra_db(self) -> AsyncAstraDB: + """Re-instantiate a new (core) client based on the instance attributes.""" + logger.info("Instantiating a new (core) AsyncAstraDB") + return AsyncAstraDB( + token=self.token_provider.get_token(), + api_endpoint=self.api_endpoint, + api_path=self.api_path, + api_version=self.api_version, + namespace=self.namespace, + caller_name=self.caller_name, + caller_version=self.caller_version, + ) + def _copy( self, *, @@ -1047,14 +1175,14 @@ def _copy( api_version: Optional[str] = None, ) -> AsyncDatabase: return AsyncDatabase( - api_endpoint=api_endpoint or self._astra_db.api_endpoint, + api_endpoint=api_endpoint or self.api_endpoint, token=coerce_token_provider(token) or self.token_provider, - namespace=namespace or self._astra_db.namespace, - caller_name=caller_name or self._astra_db.caller_name, - caller_version=caller_version or self._astra_db.caller_version, + namespace=namespace or self.namespace, + caller_name=caller_name or self.caller_name, + caller_version=caller_version or self.caller_version, environment=environment or self.environment, - api_path=api_path or self._astra_db.api_path, - 
api_version=api_version or self._astra_db.api_version, + api_path=api_path or self.api_path, + api_version=api_version or self.api_version, ) def with_options( @@ -1139,14 +1267,14 @@ def to_sync( """ return Database( - api_endpoint=api_endpoint or self._astra_db.api_endpoint, + api_endpoint=api_endpoint or self.api_endpoint, token=coerce_token_provider(token) or self.token_provider, - namespace=namespace or self._astra_db.namespace, - caller_name=caller_name or self._astra_db.caller_name, - caller_version=caller_version or self._astra_db.caller_version, + namespace=namespace or self.namespace, + caller_name=caller_name or self.caller_name, + caller_version=caller_version or self.caller_version, environment=environment or self.environment, - api_path=api_path or self._astra_db.api_path, - api_version=api_version or self._astra_db.api_version, + api_path=api_path or self.api_path, + api_version=api_version or self.api_version, ) def set_caller( @@ -1168,10 +1296,34 @@ def set_caller( """ logger.info(f"setting caller to {caller_name}/{caller_version}") - self._astra_db.set_caller( - caller_name=caller_name, - caller_version=caller_version, - ) + self.caller_name = caller_name + self.caller_version = caller_version + self._astra_db = self._refresh_astra_db() + + def use_namespace(self, namespace: str) -> None: + """ + Switch to a new working namespace for this database. + This method changes (mutates) the AsyncDatabase instance. + + Note that this method does not create the namespace, which should exist + already (created for instance with a `DatabaseAdmin.async_create_namespace` call). + + Args: + namespace: the new namespace to use as the database working namespace. + + Returns: + None. + + Example: + >>> asyncio.run(my_async_db.list_collection_names()) + ['coll_1', 'coll_2'] + >>> my_async_db.use_namespace("an_empty_namespace") + >>> asyncio.run(my_async_db.list_collection_names()) + [] + """ + logger.info(f"switching to namespace '{namespace}'") + self.using_namespace = namespace + self._astra_db = self._refresh_astra_db() def info(self) -> DatabaseInfo: """ @@ -1195,7 +1347,7 @@ def info(self) -> DatabaseInfo: logger.info("getting database info") database_info = fetch_database_info( - self._astra_db.api_endpoint, + self.api_endpoint, token=self.token_provider.get_token(), namespace=self.namespace, ) @@ -1217,7 +1369,7 @@ def id(self) -> str: '01234567-89ab-cdef-0123-456789abcdef' """ - parsed_api_endpoint = parse_api_endpoint(self._astra_db.api_endpoint) + parsed_api_endpoint = parse_api_endpoint(self.api_endpoint) if parsed_api_endpoint is not None: return parsed_api_endpoint.database_id else: @@ -1243,17 +1395,20 @@ def name(self) -> str: return self._name @property - def namespace(self) -> str: + def namespace(self) -> Optional[str]: """ The namespace this database uses as target for all commands when no method-call-specific namespace is specified. + Returns: + the working namespace (a string), or None if not set. + Example: >>> my_async_db.namespace 'the_keyspace' """ - return self._astra_db.namespace + return self.using_namespace async def get_collection( self, @@ -1282,7 +1437,7 @@ async def get_collection( each Data API call will include the necessary embedding-related headers as specified by this parameter. If a string is passed, it translates into the one "embedding api key" header - (i.e. `astrapy.authentication.StaticEmbeddingHeadersProvider`). + (i.e. `astrapy.authentication.EmbeddingAPIKeyHeaderProvider`). 
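# --- Illustrative sketch (not part of the diff): sync/async conversion keeps the new attributes ---
# to_sync()/to_async() copy api_endpoint, namespace, caller info and token over to
# the counterpart class (`my_async_db` is assumed).
sync_db = my_async_db.to_sync()
assert sync_db.api_endpoint == my_async_db.api_endpoint
assert sync_db.namespace == my_async_db.namespace
assert sync_db.caller_name == my_async_db.caller_name
# --- end of sketch ---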
For some vectorize providers/models, if using header-based authentication, specialized subclasses of `astrapy.authentication.EmbeddingHeadersProvider` should be supplied. @@ -1318,7 +1473,12 @@ async def get_collection( # lazy importing here against circular-import error from astrapy.collection import AsyncCollection - _namespace = namespace or self._astra_db.namespace + _namespace = namespace or self.namespace + if _namespace is None: + raise ValueError( + "No namespace specified. This operation requires a namespace to " + "be set, e.g. through the `use_namespace` method." + ) return AsyncCollection( self, name, @@ -1393,7 +1553,7 @@ async def create_collection( each Data API call will include the necessary embedding-related headers as specified by this parameter. If a string is passed, it translates into the one "embedding api key" header - (i.e. `astrapy.authentication.StaticEmbeddingHeadersProvider`). + (i.e. `astrapy.authentication.EmbeddingAPIKeyHeaderProvider`). For some vectorize providers/models, if using header-based authentication, specialized subclasses of `astrapy.authentication.EmbeddingHeadersProvider` should be supplied. @@ -1456,11 +1616,19 @@ async def create_collection( ) else: existing_names = [] - driver_db = self._astra_db.copy(namespace=namespace) + + _namespace = namespace or self.namespace + if _namespace is None: + raise ValueError( + "No namespace specified. This operation requires a namespace to " + "be set, e.g. through the `use_namespace` method." + ) + + driver_db = self._astra_db.copy(namespace=_namespace) if name in existing_names: raise CollectionAlreadyExistsException( text=f"CollectionInvalid: collection {name} already exists", - namespace=driver_db.namespace, + namespace=_namespace, collection_name=name, ) @@ -1534,6 +1702,11 @@ async def drop_collection( logger.info(f"finished dropping collection '{_name}'") return dc_response.get("status", {}) # type: ignore[no-any-return] else: + if self.namespace is None: + raise ValueError( + "No namespace specified. This operation requires a namespace to " + "be set, e.g. through the `use_namespace` method." + ) logger.info(f"dropping collection '{name_or_collection}'") dc_response = await self._astra_db.delete_collection( name_or_collection, @@ -1575,13 +1748,16 @@ def list_collections( * coll: CollectionDescriptor(name='my_v_col', options=CollectionOptions()) """ - _client: AsyncAstraDB - if namespace: - _client = self._astra_db.copy(namespace=namespace) - else: - _client = self._astra_db + _namespace = namespace or self.namespace + if _namespace is None: + raise ValueError( + "No namespace specified. This operation requires a namespace to " + "be set, e.g. through the `use_namespace` method." + ) + + driver_db = self._astra_db.copy(namespace=_namespace) logger.info("getting collections") - gc_response = _client.to_sync().get_collections( + gc_response = driver_db.to_sync().get_collections( options={"explain": True}, timeout_info=base_timeout_info(max_time_ms), ) @@ -1594,7 +1770,7 @@ def list_collections( # we know this is a list of dicts, to marshal into "descriptors" logger.info("finished getting collections") return AsyncCommandCursor( - address=self._astra_db.base_url, + address=driver_db.base_url, items=[ CollectionDescriptor.from_dict(col_dict) for col_dict in gc_response["status"]["collections"] @@ -1624,8 +1800,15 @@ async def list_collection_names( ['a_collection', 'another_col'] """ + _namespace = namespace or self.namespace + if _namespace is None: + raise ValueError( + "No namespace specified. 
This operation requires a namespace to " + "be set, e.g. through the `use_namespace` method." + ) + logger.info("getting collection names") - gc_response = await self._astra_db.copy(namespace=namespace).get_collections( + gc_response = await self._astra_db.copy(namespace=_namespace).get_collections( timeout_info=base_timeout_info(max_time_ms) ) if "collections" not in gc_response.get("status", {}): @@ -1673,12 +1856,16 @@ async def command( {'status': {'count': 123}} """ - if namespace: - _client = self._astra_db.copy(namespace=namespace) - else: - _client = self._astra_db + _namespace = namespace or self.namespace + if _namespace is None: + raise ValueError( + "No namespace specified. This operation requires a namespace to " + "be set, e.g. through the `use_namespace` method." + ) + + driver_db = self._astra_db.copy(namespace=_namespace) if collection_name: - _collection = await _client.collection(collection_name) + _collection = await driver_db.collection(collection_name) logger.info(f"issuing custom command to API (on '{collection_name}')") req_response = await _collection.post_raw_request( body=body, @@ -1690,7 +1877,7 @@ async def command( return req_response else: logger.info("issuing custom command to API") - req_response = await _client.post_raw_request( + req_response = await driver_db.post_raw_request( body=body, timeout_info=base_timeout_info(max_time_ms), ) @@ -1743,13 +1930,15 @@ def get_database_admin( from astrapy.admin import AstraDBDatabaseAdmin, DataAPIDatabaseAdmin if self.environment in Environment.astra_db_values: - return AstraDBDatabaseAdmin.from_api_endpoint( - api_endpoint=self._astra_db.api_endpoint, + return AstraDBDatabaseAdmin( + api_endpoint=self.api_endpoint, token=coerce_token_provider(token) or self.token_provider, - caller_name=self._astra_db.caller_name, - caller_version=self._astra_db.caller_version, + environment=self.environment, + caller_name=self.caller_name, + caller_version=self.caller_version, dev_ops_url=dev_ops_url, dev_ops_api_version=dev_ops_api_version, + spawner_database=self, ) else: if dev_ops_url is not None: @@ -1761,11 +1950,12 @@ def get_database_admin( "Parameter `dev_ops_api_version` not supported outside of Astra DB." ) return DataAPIDatabaseAdmin( - api_endpoint=self._astra_db.api_endpoint, + api_endpoint=self.api_endpoint, token=coerce_token_provider(token) or self.token_provider, environment=self.environment, - api_path=self._astra_db.api_path, - api_version=self._astra_db.api_version, - caller_name=self._astra_db.caller_name, - caller_version=self._astra_db.caller_version, + api_path=self.api_path, + api_version=self.api_version, + caller_name=self.caller_name, + caller_version=self.caller_version, + spawner_database=self, ) diff --git a/astrapy/db/__init__.py b/astrapy/db/__init__.py index cce77543..763b50a3 100644 --- a/astrapy/db/__init__.py +++ b/astrapy/db/__init__.py @@ -14,6 +14,8 @@ """Core "db" subpackage, exported here to preserve import patterns.""" +from __future__ import annotations + from astrapy.core.db import ( AstraDB, AstraDBCollection, diff --git a/astrapy/info.py b/astrapy/info.py index d4c256e1..70479d2a 100644 --- a/astrapy/info.py +++ b/astrapy/info.py @@ -14,6 +14,7 @@ from __future__ import annotations +import warnings from dataclasses import dataclass from typing import Any, Dict, List, Optional @@ -27,7 +28,7 @@ class DatabaseInfo: Attributes: id: the database ID. region: the ID of the region through which the connection to DB is done. - namespace: the namespace this DB is set to work with. 
+ namespace: the namespace this DB is set to work with. None if not set. name: the database name. Not necessarily unique: there can be multiple databases with the same name. environment: a label, whose value can be `Environment.PROD`, @@ -50,7 +51,7 @@ class DatabaseInfo: id: str region: str - namespace: str + namespace: Optional[str] name: str environment: str raw_info: Optional[Dict[str, Any]] @@ -435,3 +436,368 @@ def from_dict(raw_dict: Dict[str, Any]) -> CollectionDescriptor: options=CollectionOptions.from_dict(raw_dict.get("options") or {}), raw_descriptor=raw_dict, ) + + +@dataclass +class EmbeddingProviderParameter: + """ + A representation of a parameter as returned by the 'findEmbeddingProviders' + Data API endpoint. + + Attributes: + default_value: the default value for the parameter. + help: a textual description of the parameter. + name: the name to use when passing the parameter for vectorize operations. + required: whether the parameter is required or not. + parameter_type: a textual description of the data type for the parameter. + validation: a dictionary describing a parameter-specific validation policy. + """ + + default_value: Any + help: Optional[str] + name: str + required: bool + parameter_type: str + validation: Dict[str, Any] + + def __repr__(self) -> str: + return f"EmbeddingProviderParameter(name='{self.name}')" + + def as_dict(self) -> Dict[str, Any]: + """Recast this object into a dictionary.""" + + return { + "defaultValue": self.default_value, + "help": self.help, + "name": self.name, + "required": self.required, + "type": self.parameter_type, + "validation": self.validation, + } + + @staticmethod + def from_dict(raw_dict: Dict[str, Any]) -> EmbeddingProviderParameter: + """ + Create an instance of EmbeddingProviderParameter from a dictionary + such as one from the Data API. + """ + + residual_keys = raw_dict.keys() - { + "defaultValue", + "help", + "name", + "required", + "type", + "validation", + } + if residual_keys: + warnings.warn( + "Unexpected key(s) encountered parsing a dictionary into " + f"an `EmbeddingProviderParameter`: '{','.join(sorted(residual_keys))}'" + ) + return EmbeddingProviderParameter( + default_value=raw_dict["defaultValue"], + help=raw_dict["help"], + name=raw_dict["name"], + required=raw_dict["required"], + parameter_type=raw_dict["type"], + validation=raw_dict["validation"], + ) + + +@dataclass +class EmbeddingProviderModel: + """ + A representation of an embedding model as returned by the 'findEmbeddingProviders' + Data API endpoint. + + Attributes: + name: the model name as must be passed when issuing + vectorize operations to the API. + parameters: a list of the `EmbeddingProviderParameter` objects the model admits. + vector_dimension: an integer for the dimensionality of the embedding model. + if this is None, the dimension can assume multiple values as specified + by a corresponding parameter listed with the model. + """ + + name: str + parameters: List[EmbeddingProviderParameter] + vector_dimension: Optional[int] + + def __repr__(self) -> str: + return f"EmbeddingProviderModel(name='{self.name}')" + + def as_dict(self) -> Dict[str, Any]: + """Recast this object into a dictionary.""" + + return { + "name": self.name, + "parameters": [parameter.as_dict() for parameter in self.parameters], + "vectorDimension": self.vector_dimension, + } + + @staticmethod + def from_dict(raw_dict: Dict[str, Any]) -> EmbeddingProviderModel: + """ + Create an instance of EmbeddingProviderModel from a dictionary + such as one from the Data API. 
+ """ + + residual_keys = raw_dict.keys() - { + "name", + "parameters", + "vectorDimension", + } + if residual_keys: + warnings.warn( + "Unexpected key(s) encountered parsing a dictionary into " + f"an `EmbeddingProviderModel`: '{','.join(sorted(residual_keys))}'" + ) + return EmbeddingProviderModel( + name=raw_dict["name"], + parameters=[ + EmbeddingProviderParameter.from_dict(param_dict) + for param_dict in raw_dict["parameters"] + ], + vector_dimension=raw_dict["vectorDimension"], + ) + + +@dataclass +class EmbeddingProviderToken: + """ + A representation of a "token", that is a specific secret string, needed by + an embedding model; this models a part of the response from the + 'findEmbeddingProviders' Data API endpoint. + + Attributes: + accepted: the name of this "token" as seen by the Data API. This is the + name that should be used in the clients when supplying the secret, + whether as header or by shared-secret. + forwarded: the name used by the API when issuing the embedding request + to the embedding provider. This is of no direct interest for the Data API user. + """ + + accepted: str + forwarded: str + + def __repr__(self) -> str: + return f"EmbeddingProviderToken('{self.accepted}')" + + def as_dict(self) -> Dict[str, Any]: + """Recast this object into a dictionary.""" + + return { + "accepted": self.accepted, + "forwarded": self.forwarded, + } + + @staticmethod + def from_dict(raw_dict: Dict[str, Any]) -> EmbeddingProviderToken: + """ + Create an instance of EmbeddingProviderToken from a dictionary + such as one from the Data API. + """ + + residual_keys = raw_dict.keys() - { + "accepted", + "forwarded", + } + if residual_keys: + warnings.warn( + "Unexpected key(s) encountered parsing a dictionary into " + f"an `EmbeddingProviderToken`: '{','.join(sorted(residual_keys))}'" + ) + return EmbeddingProviderToken( + accepted=raw_dict["accepted"], + forwarded=raw_dict["forwarded"], + ) + + +@dataclass +class EmbeddingProviderAuthentication: + """ + A representation of an authentication mode for using an embedding model, + modeling the corresponding part of the response returned by the + 'findEmbeddingProviders' Data API endpoint (namely "supportedAuthentication"). + + Attributes: + enabled: whether this authentication mode is available for a given model. + tokens: a list of `EmbeddingProviderToken` objects, + detailing the secrets required for the authentication mode. + """ + + enabled: bool + tokens: List[EmbeddingProviderToken] + + def __repr__(self) -> str: + return ( + f"EmbeddingProviderAuthentication(enabled={self.enabled}, " + f"tokens={','.join(str(token) for token in self.tokens)})" + ) + + def as_dict(self) -> Dict[str, Any]: + """Recast this object into a dictionary.""" + + return { + "enabled": self.enabled, + "tokens": [token.as_dict() for token in self.tokens], + } + + @staticmethod + def from_dict(raw_dict: Dict[str, Any]) -> EmbeddingProviderAuthentication: + """ + Create an instance of EmbeddingProviderAuthentication from a dictionary + such as one from the Data API. 
+ """ + + residual_keys = raw_dict.keys() - { + "enabled", + "tokens", + } + if residual_keys: + warnings.warn( + "Unexpected key(s) encountered parsing a dictionary into " + f"an `EmbeddingProviderAuthentication`: '{','.join(sorted(residual_keys))}'" + ) + return EmbeddingProviderAuthentication( + enabled=raw_dict["enabled"], + tokens=[ + EmbeddingProviderToken.from_dict(token_dict) + for token_dict in raw_dict["tokens"] + ], + ) + + +@dataclass +class EmbeddingProvider: + """ + A representation of an embedding provider, as returned by the 'findEmbeddingProviders' + Data API endpoint. + + Attributes: + display_name: a version of the provider name for display and pretty printing. + Not to be used when issuing vectorize API requests (for the latter, it is + the key in the providers dictionary that is required). + models: a list of `EmbeddingProviderModel` objects pertaining to the provider. + parameters: a list of `EmbeddingProviderParameter` objects common to all models + for this provider. + supported_authentication: a dictionary of the authentication modes for + this provider. Note that disabled modes may still appear in this map, + albeit with the `enabled` property set to False. + url: a string template for the URL used by the Data API when issuing the request + toward the embedding provider. This is of no direct concern to the Data API user. + """ + + def __repr__(self) -> str: + return f"EmbeddingProvider(display_name='{self.display_name}', models={self.models})" + + display_name: Optional[str] + models: List[EmbeddingProviderModel] + parameters: List[EmbeddingProviderParameter] + supported_authentication: Dict[str, EmbeddingProviderAuthentication] + url: Optional[str] + + def as_dict(self) -> Dict[str, Any]: + """Recast this object into a dictionary.""" + + return { + "displayName": self.display_name, + "models": [model.as_dict() for model in self.models], + "parameters": [parameter.as_dict() for parameter in self.parameters], + "supportedAuthentication": { + sa_name: sa_value.as_dict() + for sa_name, sa_value in self.supported_authentication.items() + }, + "url": self.url, + } + + @staticmethod + def from_dict(raw_dict: Dict[str, Any]) -> EmbeddingProvider: + """ + Create an instance of EmbeddingProvider from a dictionary + such as one from the Data API. + """ + + residual_keys = raw_dict.keys() - { + "displayName", + "models", + "parameters", + "supportedAuthentication", + "url", + } + if residual_keys: + warnings.warn( + "Unexpected key(s) encountered parsing a dictionary into " + f"an `EmbeddingProvider`: '{','.join(sorted(residual_keys))}'" + ) + return EmbeddingProvider( + display_name=raw_dict["displayName"], + models=[ + EmbeddingProviderModel.from_dict(model_dict) + for model_dict in raw_dict["models"] + ], + parameters=[ + EmbeddingProviderParameter.from_dict(param_dict) + for param_dict in raw_dict["parameters"] + ], + supported_authentication={ + sa_name: EmbeddingProviderAuthentication.from_dict(sa_dict) + for sa_name, sa_dict in raw_dict["supportedAuthentication"].items() + }, + url=raw_dict["url"], + ) + + +@dataclass +class FindEmbeddingProvidersResult: + """ + A representation of the whole response from the 'findEmbeddingProviders' + Data API endpoint. + + Attributes: + embedding_providers: a dictionary of provider names to EmbeddingProvider objects. + raw_info: a (nested) dictionary containing the original full response from the endpoint. 
+ """ + + def __repr__(self) -> str: + return ( + "FindEmbeddingProvidersResult(embedding_providers=" + f"{', '.join(sorted(self.embedding_providers.keys()))})" + ) + + embedding_providers: Dict[str, EmbeddingProvider] + raw_info: Optional[Dict[str, Any]] + + def as_dict(self) -> Dict[str, Any]: + """Recast this object into a dictionary.""" + + return { + "embeddingProviders": { + ep_name: e_provider.as_dict() + for ep_name, e_provider in self.embedding_providers.items() + }, + } + + @staticmethod + def from_dict(raw_dict: Dict[str, Any]) -> FindEmbeddingProvidersResult: + """ + Create an instance of FindEmbeddingProvidersResult from a dictionary + such as one from the Data API. + """ + + residual_keys = raw_dict.keys() - { + "embeddingProviders", + } + if residual_keys: + warnings.warn( + "Unexpected key(s) encountered parsing a dictionary into " + f"a `FindEmbeddingProvidersResult`: '{','.join(sorted(residual_keys))}'" + ) + return FindEmbeddingProvidersResult( + raw_info=raw_dict, + embedding_providers={ + ep_name: EmbeddingProvider.from_dict(ep_body) + for ep_name, ep_body in raw_dict["embeddingProviders"].items() + }, + ) diff --git a/astrapy/ops/__init__.py b/astrapy/ops/__init__.py index 4797841b..8915312e 100644 --- a/astrapy/ops/__init__.py +++ b/astrapy/ops/__init__.py @@ -14,6 +14,8 @@ """Core "ops" subpackage, exported here to preserve import patterns.""" +from __future__ import annotations + from astrapy.core.ops import AstraDBOps __all__ = [ diff --git a/pyproject.toml b/pyproject.toml index 4a8f31f5..f99ddb92 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "astrapy" -version = "1.3.2" +version = "1.4.0" description = "AstraPy is a Pythonic SDK for DataStax Astra and its Data API" authors = [ "Stefano Lottini ", diff --git a/scripts/astrapy_latest_interface.py b/scripts/astrapy_latest_interface.py index 2f2934d4..136d402e 100644 --- a/scripts/astrapy_latest_interface.py +++ b/scripts/astrapy_latest_interface.py @@ -16,8 +16,8 @@ api_endpoint = os.environ["ASTRA_DB_API_ENDPOINT"] # Initialize our vector db -my_client = astrapy.DataAPIClient(token) -my_database = my_client.get_database(api_endpoint) +my_client = astrapy.DataAPIClient() +my_database = my_client.get_database(api_endpoint, token=token) # In case we already have the collection, let's clear it out my_database.drop_collection("collection_test") @@ -29,8 +29,8 @@ my_collection.insert_one( { "_id": "1", - "name": "Coded Cleats Copy", - "description": "ChatGPT integrated sneakers that talk to you", + "name": "Coded Cleats", + "description": "GenAI-integrated sneakers that talk to you", "$vector": [0.25, 0.25, 0.25, 0.25, 0.25], }, ) diff --git a/tests/__init__.py b/tests/__init__.py index 2c9ca172..84497ed1 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -11,3 +11,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from __future__ import annotations diff --git a/tests/conftest.py b/tests/conftest.py index 2ccaa1ee..501bd9b3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,6 +16,8 @@ Main conftest for shared fixtures (if any). 
""" +from __future__ import annotations + import functools import warnings from typing import Any, Awaitable, Callable, Optional, Tuple, TypedDict diff --git a/tests/core/__init__.py b/tests/core/__init__.py index 2c9ca172..84497ed1 100644 --- a/tests/core/__init__.py +++ b/tests/core/__init__.py @@ -11,3 +11,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from __future__ import annotations diff --git a/tests/core/conftest.py b/tests/core/conftest.py index fbc109d8..7e9710a8 100644 --- a/tests/core/conftest.py +++ b/tests/core/conftest.py @@ -16,6 +16,8 @@ Test fixtures for 'core' testing """ +from __future__ import annotations + import math from typing import AsyncIterable, Dict, Iterable, List, Optional, Set, TypeVar diff --git a/tests/core/test_async_db_ddl.py b/tests/core/test_async_db_ddl.py index 6b395c28..5ae60d18 100644 --- a/tests/core/test_async_db_ddl.py +++ b/tests/core/test_async_db_ddl.py @@ -16,6 +16,8 @@ Tests for the `db.py` parts related to DML & client creation """ +from __future__ import annotations + import logging import pytest diff --git a/tests/core/test_async_db_dml.py b/tests/core/test_async_db_dml.py index 6231b1b4..5aa6aae3 100644 --- a/tests/core/test_async_db_dml.py +++ b/tests/core/test_async_db_dml.py @@ -17,6 +17,8 @@ (i.e. non `vector_*` methods) """ +from __future__ import annotations + import datetime import logging import uuid @@ -600,7 +602,6 @@ async def test_chunked_insert_many_failures( chunk_size=2, concurrency=1, ) - assert len((await async_empty_v_collection.find({}))["data"]["documents"]) == 2 await async_empty_v_collection.delete_many({}) with pytest.raises(APIRequestError): @@ -611,7 +612,24 @@ async def test_chunked_insert_many_failures( chunk_size=2, concurrency=2, ) - assert len((await async_empty_v_collection.find({}))["data"]["documents"]) >= 2 + + +@pytest.mark.describe( + "chunked_insert_many, failure modes unordered with concurrency (async)" +) +async def test_chunked_insert_many_failures_unordered_concurrent( + async_empty_v_collection: AsyncAstraDBCollection, +) -> None: + """ + Note: this is split from the `test_chunked_insert_many_failures` test + because the last command there (insert_many of bad_docs with ordered but concurrency>1) + leaves orphan running tasks once the exception is raised. This ultimately + traces back to known asyncio.gather behaviour - it is not a bug, + just something that makes the premises for the next testing ill-defined. 
+ + More info on the underlying issue: https://stackoverflow.com/a/59074112/16545960 + """ + dup_docs = [{"_id": tid} for tid in ["a", "b", "b", "d", "e", "f"]] await async_empty_v_collection.delete_many({}) ins_result = await async_empty_v_collection.chunked_insert_many( diff --git a/tests/core/test_async_db_dml_pagination.py b/tests/core/test_async_db_dml_pagination.py index 0634fd62..09ac4878 100644 --- a/tests/core/test_async_db_dml_pagination.py +++ b/tests/core/test_async_db_dml_pagination.py @@ -16,6 +16,8 @@ Tests for the `db.py` parts on pagination primitives """ +from __future__ import annotations + import logging from typing import Optional diff --git a/tests/core/test_async_db_dml_vector.py b/tests/core/test_async_db_dml_vector.py index 353406cd..9e5baef3 100644 --- a/tests/core/test_async_db_dml_vector.py +++ b/tests/core/test_async_db_dml_vector.py @@ -16,6 +16,8 @@ Tests for the `db.py` parts on data manipulation `vector_*` methods """ +from __future__ import annotations + import logging from typing import Iterable, List, cast diff --git a/tests/core/test_conversions.py b/tests/core/test_conversions.py index 470b85a9..85405b3c 100644 --- a/tests/core/test_conversions.py +++ b/tests/core/test_conversions.py @@ -16,6 +16,8 @@ Tests for the User-Agent customization logic """ +from __future__ import annotations + import logging import pytest diff --git a/tests/core/test_db_ddl.py b/tests/core/test_db_ddl.py index 25f4fdde..dc737850 100644 --- a/tests/core/test_db_ddl.py +++ b/tests/core/test_db_ddl.py @@ -16,6 +16,8 @@ Tests for the `db.py` parts related to DML & client creation """ +from __future__ import annotations + import logging import pytest diff --git a/tests/core/test_db_dml.py b/tests/core/test_db_dml.py index f5f438ff..0abb3264 100644 --- a/tests/core/test_db_dml.py +++ b/tests/core/test_db_dml.py @@ -17,6 +17,8 @@ (i.e. non `vector_*` methods) """ +from __future__ import annotations + import datetime import json import logging @@ -590,7 +592,17 @@ def test_chunked_insert_many_failures( chunk_size=2, concurrency=2, ) - assert len(empty_v_collection.find({})["data"]["documents"]) >= 2 + + +@pytest.mark.describe("chunked_insert_many, failure modes unordered with concurrency") +def test_chunked_insert_many_failures_unordered_concurrent( + empty_v_collection: AstraDBCollection, +) -> None: + """ + Note: no real reason to split this from `test_chunked_insert_many_failures`, + other than to preserve symmetry with the async counterpart. 
+ """ + dup_docs = [{"_id": tid} for tid in ["a", "b", "b", "d", "e", "f"]] empty_v_collection.delete_many({}) ins_result = empty_v_collection.chunked_insert_many( diff --git a/tests/core/test_db_dml_pagination.py b/tests/core/test_db_dml_pagination.py index 61df693f..30418d98 100644 --- a/tests/core/test_db_dml_pagination.py +++ b/tests/core/test_db_dml_pagination.py @@ -16,6 +16,8 @@ Tests for the `db.py` parts on pagination primitives """ +from __future__ import annotations + import logging from typing import Optional diff --git a/tests/core/test_db_dml_vector.py b/tests/core/test_db_dml_vector.py index a4eb3ad9..28532b3c 100644 --- a/tests/core/test_db_dml_vector.py +++ b/tests/core/test_db_dml_vector.py @@ -16,6 +16,8 @@ Tests for the `db.py` parts on data manipulation `vector_*` methods """ +from __future__ import annotations + import logging from typing import Iterable, List, cast diff --git a/tests/core/test_endpoint_parsing.py b/tests/core/test_endpoint_parsing.py index 7ee2b93d..27346134 100644 --- a/tests/core/test_endpoint_parsing.py +++ b/tests/core/test_endpoint_parsing.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import pytest from astrapy.admin import parse_generic_api_url diff --git a/tests/core/test_ids.py b/tests/core/test_ids.py index ae680c9b..ffcbb26b 100644 --- a/tests/core/test_ids.py +++ b/tests/core/test_ids.py @@ -16,6 +16,8 @@ Unit tests for the ObjectIds and UUIDn conversions """ +from __future__ import annotations + import json import uuid as lib_uuid diff --git a/tests/core/test_imports.py b/tests/core/test_imports.py index 8983822b..c03cc647 100644 --- a/tests/core/test_imports.py +++ b/tests/core/test_imports.py @@ -13,9 +13,11 @@ # limitations under the License. """ -Tests for the User-Agent customization logic +Tests for the core-related imports """ +from __future__ import annotations + import pytest diff --git a/tests/core/test_logging.py b/tests/core/test_logging.py index 7e039859..0d435763 100644 --- a/tests/core/test_logging.py +++ b/tests/core/test_logging.py @@ -16,6 +16,8 @@ Tests for the "TRACE" custom logging level """ +from __future__ import annotations + import logging import pytest diff --git a/tests/core/test_ops.py b/tests/core/test_ops.py index 95165d7b..b6ada837 100644 --- a/tests/core/test_ops.py +++ b/tests/core/test_ops.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations + import itertools import logging from typing import Any, Dict, List, cast diff --git a/tests/core/test_timeouts.py b/tests/core/test_timeouts.py index e7417a39..a9a7753c 100644 --- a/tests/core/test_timeouts.py +++ b/tests/core/test_timeouts.py @@ -16,6 +16,8 @@ Tests for the `db.py` parts related to DML & client creation """ +from __future__ import annotations + import logging import httpx diff --git a/tests/core/test_user_agent.py b/tests/core/test_user_agent.py index f914ae56..d49dd5cc 100644 --- a/tests/core/test_user_agent.py +++ b/tests/core/test_user_agent.py @@ -16,6 +16,8 @@ Tests for the User-Agent customization logic """ +from __future__ import annotations + import logging import pytest diff --git a/tests/idiomatic/__init__.py b/tests/idiomatic/__init__.py index 2c9ca172..84497ed1 100644 --- a/tests/idiomatic/__init__.py +++ b/tests/idiomatic/__init__.py @@ -11,3 +11,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from __future__ import annotations diff --git a/tests/idiomatic/conftest.py b/tests/idiomatic/conftest.py index 358a93ec..80f6076e 100644 --- a/tests/idiomatic/conftest.py +++ b/tests/idiomatic/conftest.py @@ -14,6 +14,8 @@ """Fixtures specific to the non-core-side testing.""" +from __future__ import annotations + from typing import Iterable import pytest diff --git a/tests/idiomatic/integration/__init__.py b/tests/idiomatic/integration/__init__.py index 2c9ca172..84497ed1 100644 --- a/tests/idiomatic/integration/__init__.py +++ b/tests/idiomatic/integration/__init__.py @@ -11,3 +11,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from __future__ import annotations diff --git a/tests/idiomatic/integration/test_admin.py b/tests/idiomatic/integration/test_admin.py index ff4f04dc..fcc93f45 100644 --- a/tests/idiomatic/integration/test_admin.py +++ b/tests/idiomatic/integration/test_admin.py @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
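# --- Illustrative sketch (not part of the diff): parsing a findEmbeddingProviders payload ---
# Exercises the dataclasses added to astrapy/info.py above; the payload content
# (provider name, model, URL, token names) is made up, only the key layout follows
# the from_dict() implementations.
from astrapy.info import FindEmbeddingProvidersResult

raw_response = {
    "embeddingProviders": {
        "a-certain-provider": {
            "displayName": "A Certain Provider",
            "url": "https://api.example.com/v1/embeddings",
            "parameters": [],
            "supportedAuthentication": {
                "HEADER": {
                    "enabled": True,
                    "tokens": [
                        {"accepted": "x-embedding-api-key", "forwarded": "Authorization"}
                    ],
                }
            },
            "models": [
                {"name": "some-embedding-model", "vectorDimension": 1024, "parameters": []}
            ],
        }
    }
}

providers = FindEmbeddingProvidersResult.from_dict(raw_response)
provider = providers.embedding_providers["a-certain-provider"]
assert provider.models[0].vector_dimension == 1024
assert provider.supported_authentication["HEADER"].enabled
# --- end of sketch ---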
+from __future__ import annotations + import time from typing import Any, Awaitable, Callable, List, Optional, Tuple import pytest -from astrapy import DataAPIClient +from astrapy import AsyncDatabase, DataAPIClient, Database from astrapy.admin import API_ENDPOINT_TEMPLATE_MAP from ..conftest import ( @@ -582,3 +584,43 @@ async def _awaiter4() -> bool: max_seconds=DATABASE_TIMEOUT, acondition=_awaiter4, ) + + @pytest.mark.describe( + "test of the update_db_namespace flag for AstraDBDatabaseAdmin, sync" + ) + def test_astra_updatedbnamespace_sync(self, sync_database: Database) -> None: + NEW_NS_NAME_NOT_UPDATED = "tnudn_notupd" + NEW_NS_NAME_UPDATED = "tnudn_upd" + + namespace0 = sync_database.namespace + database_admin = sync_database.get_database_admin() + database_admin.create_namespace(NEW_NS_NAME_NOT_UPDATED) + assert sync_database.namespace == namespace0 + + database_admin.create_namespace(NEW_NS_NAME_UPDATED, update_db_namespace=True) + assert sync_database.namespace == NEW_NS_NAME_UPDATED + + database_admin.drop_namespace(NEW_NS_NAME_NOT_UPDATED) + database_admin.drop_namespace(NEW_NS_NAME_UPDATED) + + @pytest.mark.describe( + "test of the update_db_namespace flag for AstraDBDatabaseAdmin, async" + ) + async def test_astra_updatedbnamespace_async( + self, async_database: AsyncDatabase + ) -> None: + NEW_NS_NAME_NOT_UPDATED = "tnudn_notupd" + NEW_NS_NAME_UPDATED = "tnudn_upd" + + namespace0 = async_database.namespace + database_admin = async_database.get_database_admin() + await database_admin.async_create_namespace(NEW_NS_NAME_NOT_UPDATED) + assert async_database.namespace == namespace0 + + await database_admin.async_create_namespace( + NEW_NS_NAME_UPDATED, update_db_namespace=True + ) + assert async_database.namespace == NEW_NS_NAME_UPDATED + + await database_admin.async_drop_namespace(NEW_NS_NAME_NOT_UPDATED) + await database_admin.async_drop_namespace(NEW_NS_NAME_UPDATED) diff --git a/tests/idiomatic/integration/test_ddl_async.py b/tests/idiomatic/integration/test_ddl_async.py index 81ee17a9..93ca37e8 100644 --- a/tests/idiomatic/integration/test_ddl_async.py +++ b/tests/idiomatic/integration/test_ddl_async.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations + import time import pytest @@ -220,6 +222,28 @@ async def test_database_list_collections_cross_namespace_async( namespace=data_api_credentials_info["secondary_namespace"] ) + @pytest.mark.skipif( + SECONDARY_NAMESPACE is None, reason="No secondary namespace provided" + ) + @pytest.mark.describe("test of Database use_namespace, async") + async def test_database_use_namespace_async( + self, + async_database: AsyncDatabase, + async_collection: AsyncCollection, + data_api_credentials_kwargs: DataAPICredentials, + data_api_credentials_info: DataAPICredentialsInfo, + ) -> None: + # make a copy to avoid mutating the fixture + at_database = async_database._copy() + assert at_database == async_database + assert at_database.namespace == data_api_credentials_kwargs["namespace"] + assert TEST_COLLECTION_NAME in await at_database.list_collection_names() + + at_database.use_namespace(data_api_credentials_info["secondary_namespace"]) # type: ignore[arg-type] + assert at_database != async_database + assert at_database.namespace == data_api_credentials_info["secondary_namespace"] + assert TEST_COLLECTION_NAME not in await at_database.list_collection_names() + @pytest.mark.skipif( SECONDARY_NAMESPACE is None, reason="No secondary namespace provided" ) @@ -310,6 +334,46 @@ async def test_tokenless_client_async( api_endpoint = data_api_credentials_kwargs["api_endpoint"] token = data_api_credentials_kwargs["token"] client = DataAPIClient(environment=data_api_credentials_info["environment"]) - a_database = client.get_async_database(api_endpoint, token=token) + a_database = client.get_async_database( + api_endpoint, + token=token, + namespace=data_api_credentials_kwargs["namespace"], + ) coll_names = await a_database.list_collection_names() assert isinstance(coll_names, list) + + @pytest.mark.skipif(not IS_ASTRA_DB, reason="Not supported outside of Astra DB") + @pytest.mark.describe( + "test database-from-admin default namespace per environment, async" + ) + async def test_database_from_admin_default_namespace_per_environment_async( + self, + data_api_credentials_kwargs: DataAPICredentials, + data_api_credentials_info: DataAPICredentialsInfo, + ) -> None: + client = DataAPIClient(environment=data_api_credentials_info["environment"]) + admin = client.get_admin(token=data_api_credentials_kwargs["token"]) + db_m = admin.get_async_database( + data_api_credentials_kwargs["api_endpoint"], + namespace="M", + ) + assert db_m.namespace == "M" + db_n = admin.get_async_database(data_api_credentials_kwargs["api_endpoint"]) + assert isinstance(db_n.namespace, str) # i.e. resolution took place + + @pytest.mark.skipif(not IS_ASTRA_DB, reason="Not supported outside of Astra DB") + @pytest.mark.describe( + "test database-from-astradbadmin default namespace per environment, async" + ) + async def test_database_from_astradbadmin_default_namespace_per_environment_async( + self, + data_api_credentials_kwargs: DataAPICredentials, + data_api_credentials_info: DataAPICredentialsInfo, + ) -> None: + client = DataAPIClient(environment=data_api_credentials_info["environment"]) + admin = client.get_admin(token=data_api_credentials_kwargs["token"]) + db_admin = admin.get_database_admin(data_api_credentials_kwargs["api_endpoint"]) + db_m = db_admin.get_async_database(namespace="M") + assert db_m.namespace == "M" + db_n = db_admin.get_async_database() + assert isinstance(db_n.namespace, str) # i.e. 
resolution took place diff --git a/tests/idiomatic/integration/test_ddl_sync.py b/tests/idiomatic/integration/test_ddl_sync.py index e42de4f6..0cc912c0 100644 --- a/tests/idiomatic/integration/test_ddl_sync.py +++ b/tests/idiomatic/integration/test_ddl_sync.py @@ -12,11 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import time import pytest from astrapy import Collection, DataAPIClient, Database +from astrapy.admin import AstraDBDatabaseAdmin, parse_api_endpoint from astrapy.constants import DefaultIdType, VectorMetric from astrapy.ids import UUID, ObjectId from astrapy.info import CollectionDescriptor, DatabaseInfo @@ -210,6 +213,28 @@ def test_database_list_collections_cross_namespace_sync( namespace=data_api_credentials_info["secondary_namespace"] ) + @pytest.mark.skipif( + SECONDARY_NAMESPACE is None, reason="No secondary namespace provided" + ) + @pytest.mark.describe("test of Database use_namespace, sync") + def test_database_use_namespace_sync( + self, + sync_database: Database, + sync_collection: Collection, + data_api_credentials_kwargs: DataAPICredentials, + data_api_credentials_info: DataAPICredentialsInfo, + ) -> None: + # make a copy to avoid mutating the fixture + t_database = sync_database._copy() + assert t_database == sync_database + assert t_database.namespace == data_api_credentials_kwargs["namespace"] + assert TEST_COLLECTION_NAME in t_database.list_collection_names() + + t_database.use_namespace(data_api_credentials_info["secondary_namespace"]) # type: ignore[arg-type] + assert t_database != sync_database + assert t_database.namespace == data_api_credentials_info["secondary_namespace"] + assert TEST_COLLECTION_NAME not in t_database.list_collection_names() + @pytest.mark.skipif( SECONDARY_NAMESPACE is None, reason="No secondary namespace provided" ) @@ -299,5 +324,87 @@ def test_tokenless_client_sync( api_endpoint = data_api_credentials_kwargs["api_endpoint"] token = data_api_credentials_kwargs["token"] client = DataAPIClient(environment=data_api_credentials_info["environment"]) - database = client.get_database(api_endpoint, token=token) + database = client.get_database( + api_endpoint, + token=token, + namespace=data_api_credentials_kwargs["namespace"], + ) assert isinstance(database.list_collection_names(), list) + + @pytest.mark.skipif(not IS_ASTRA_DB, reason="Not supported outside of Astra DB") + @pytest.mark.describe( + "test of autoregion through DevOps API for get_database(_admin), sync" + ) + def test_autoregion_getdatabase_sync( + self, + data_api_credentials_kwargs: DataAPICredentials, + data_api_credentials_info: DataAPICredentialsInfo, + ) -> None: + client = DataAPIClient(environment=data_api_credentials_info["environment"]) + parsed_api_endpoint = parse_api_endpoint( + data_api_credentials_kwargs["api_endpoint"] + ) + if parsed_api_endpoint is None: + raise ValueError( + f"Unparseable API endpoint: {data_api_credentials_kwargs['api_endpoint']}" + ) + adm = client.get_admin(token=data_api_credentials_kwargs["token"]) + # auto-region through the DevOps "db info" call + assert adm.get_database_admin( + parsed_api_endpoint.database_id + ) == adm.get_database_admin(data_api_credentials_kwargs["api_endpoint"]) + + # auto-region for get_database + assert adm.get_database( + parsed_api_endpoint.database_id, + namespace="the_ns", + ) == adm.get_database( + data_api_credentials_kwargs["api_endpoint"], namespace="the_ns" + ) + + # auto-region for the init of
AstraDBDatabaseAdmin + assert AstraDBDatabaseAdmin( + data_api_credentials_kwargs["api_endpoint"], + token=data_api_credentials_kwargs["token"], + environment=data_api_credentials_info["environment"], + ) == AstraDBDatabaseAdmin( + parsed_api_endpoint.database_id, + token=data_api_credentials_kwargs["token"], + environment=data_api_credentials_info["environment"], + ) + + @pytest.mark.skipif(not IS_ASTRA_DB, reason="Not supported outside of Astra DB") + @pytest.mark.describe( + "test database-from-admin default namespace per environment, sync" + ) + def test_database_from_admin_default_namespace_per_environment_sync( + self, + data_api_credentials_kwargs: DataAPICredentials, + data_api_credentials_info: DataAPICredentialsInfo, + ) -> None: + client = DataAPIClient(environment=data_api_credentials_info["environment"]) + admin = client.get_admin(token=data_api_credentials_kwargs["token"]) + db_m = admin.get_database( + data_api_credentials_kwargs["api_endpoint"], + namespace="M", + ) + assert db_m.namespace == "M" + db_n = admin.get_database(data_api_credentials_kwargs["api_endpoint"]) + assert isinstance(db_n.namespace, str) # i.e. resolution took place + + @pytest.mark.skipif(not IS_ASTRA_DB, reason="Not supported outside of Astra DB") + @pytest.mark.describe( + "test database-from-astradbadmin default namespace per environment, sync" + ) + def test_database_from_astradbadmin_default_namespace_per_environment_sync( + self, + data_api_credentials_kwargs: DataAPICredentials, + data_api_credentials_info: DataAPICredentialsInfo, + ) -> None: + client = DataAPIClient(environment=data_api_credentials_info["environment"]) + admin = client.get_admin(token=data_api_credentials_kwargs["token"]) + db_admin = admin.get_database_admin(data_api_credentials_kwargs["api_endpoint"]) + db_m = db_admin.get_database(namespace="M") + assert db_m.namespace == "M" + db_n = db_admin.get_database() + assert isinstance(db_n.namespace, str) # i.e. resolution took place diff --git a/tests/idiomatic/integration/test_dml_async.py b/tests/idiomatic/integration/test_dml_async.py index d1225b7a..76c2f366 100644 --- a/tests/idiomatic/integration/test_dml_async.py +++ b/tests/idiomatic/integration/test_dml_async.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import datetime from typing import Any, Dict, List diff --git a/tests/idiomatic/integration/test_dml_sync.py b/tests/idiomatic/integration/test_dml_sync.py index 37e5fac1..075e1171 100644 --- a/tests/idiomatic/integration/test_dml_sync.py +++ b/tests/idiomatic/integration/test_dml_sync.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import datetime from typing import Any, Dict, List diff --git a/tests/idiomatic/integration/test_exceptions_async.py b/tests/idiomatic/integration/test_exceptions_async.py index de4f2eaa..4b1a0508 100644 --- a/tests/idiomatic/integration/test_exceptions_async.py +++ b/tests/idiomatic/integration/test_exceptions_async.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations + from typing import List import pytest diff --git a/tests/idiomatic/integration/test_exceptions_sync.py b/tests/idiomatic/integration/test_exceptions_sync.py index 43b7e06e..ec2c0a2c 100644 --- a/tests/idiomatic/integration/test_exceptions_sync.py +++ b/tests/idiomatic/integration/test_exceptions_sync.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import pytest from astrapy import Collection, Database diff --git a/tests/idiomatic/integration/test_nonastra_admin.py b/tests/idiomatic/integration/test_nonastra_admin.py index d3ad69ee..156a9514 100644 --- a/tests/idiomatic/integration/test_nonastra_admin.py +++ b/tests/idiomatic/integration/test_nonastra_admin.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import pytest from astrapy import AsyncDatabase, Database @@ -35,6 +37,7 @@ def test_nonastra_database_admin_sync(self, sync_database: Database) -> None: - database -> create_collection/list_collection_names - drop namespace - list namespaces, check + - ALSO: test the update_db_namespace flag when creating from db->dbadmin """ NEW_NS_NAME = "tnaa_test_ns_s" NEW_COLL_NAME = "tnaa_test_coll" @@ -59,6 +62,23 @@ def test_nonastra_database_admin_sync(self, sync_database: Database) -> None: namespaces3 = set(database_admin.list_namespaces()) assert namespaces3 == namespaces1 + # update_db_namespace: + NEW_NS_NAME_NOT_UPDATED = "tnudn_notupd" + NEW_NS_NAME_UPDATED = "tnudn_upd" + + database = sync_database._copy() + + namespace0 = database.namespace + database_admin = database.get_database_admin() + database_admin.create_namespace(NEW_NS_NAME_NOT_UPDATED) + assert database.namespace == namespace0 + + database_admin.create_namespace(NEW_NS_NAME_UPDATED, update_db_namespace=True) + assert database.namespace == NEW_NS_NAME_UPDATED + + database_admin.drop_namespace(NEW_NS_NAME_NOT_UPDATED) + database_admin.drop_namespace(NEW_NS_NAME_UPDATED) + @pytest.mark.describe( "test of the namespace crud with non-Astra DataAPIDatabaseAdmin, async" ) @@ -74,6 +94,7 @@ async def test_nonastra_database_admin_async( - database -> create_collection/list_collection_names - drop namespace - list namespaces, check + - ALSO: test the update_db_namespace flag when creating from db->dbadmin """ NEW_NS_NAME = "tnaa_test_ns_a" NEW_COLL_NAME = "tnaa_test_coll" @@ -97,3 +118,22 @@ async def test_nonastra_database_admin_async( namespaces3 = set(await database_admin.async_list_namespaces()) assert namespaces3 == namespaces1 + + # update_db_namespace: + NEW_NS_NAME_NOT_UPDATED = "tnudn_notupd" + NEW_NS_NAME_UPDATED = "tnudn_upd" + + adatabase = async_database._copy() + + namespace0 = adatabase.namespace + database_admin = adatabase.get_database_admin() + await database_admin.async_create_namespace(NEW_NS_NAME_NOT_UPDATED) + assert adatabase.namespace == namespace0 + + await database_admin.async_create_namespace( + NEW_NS_NAME_UPDATED, update_db_namespace=True + ) + assert adatabase.namespace == NEW_NS_NAME_UPDATED + + await database_admin.async_drop_namespace(NEW_NS_NAME_NOT_UPDATED) + await database_admin.async_drop_namespace(NEW_NS_NAME_UPDATED) diff --git a/tests/idiomatic/integration/test_timeout_async.py b/tests/idiomatic/integration/test_timeout_async.py index c08dce71..07b56e0f 100644 --- a/tests/idiomatic/integration/test_timeout_async.py +++ 
b/tests/idiomatic/integration/test_timeout_async.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import asyncio import pytest diff --git a/tests/idiomatic/integration/test_timeout_sync.py b/tests/idiomatic/integration/test_timeout_sync.py index de6a498e..aa22d232 100644 --- a/tests/idiomatic/integration/test_timeout_sync.py +++ b/tests/idiomatic/integration/test_timeout_sync.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import time import pytest diff --git a/tests/idiomatic/unit/__init__.py b/tests/idiomatic/unit/__init__.py index 2c9ca172..84497ed1 100644 --- a/tests/idiomatic/unit/__init__.py +++ b/tests/idiomatic/unit/__init__.py @@ -11,3 +11,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from __future__ import annotations diff --git a/tests/idiomatic/unit/test_admin_conversions.py b/tests/idiomatic/unit/test_admin_conversions.py index 432c7275..7538c05c 100644 --- a/tests/idiomatic/unit/test_admin_conversions.py +++ b/tests/idiomatic/unit/test_admin_conversions.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import httpx import pytest @@ -163,8 +165,9 @@ def test_astradbadmin_conversions(self) -> None: ) def test_astradbdatabaseadmin_conversions(self) -> None: adda1 = AstraDBDatabaseAdmin( - "i1", + "01234567-89ab-cdef-0123-456789abcdef", token="t1", + region="reg", environment="dev", caller_name="cn", caller_version="cv", @@ -172,8 +175,9 @@ def test_astradbdatabaseadmin_conversions(self) -> None: dev_ops_api_version="dvv", ) adda2 = AstraDBDatabaseAdmin( - "i1", + "01234567-89ab-cdef-0123-456789abcdef", token="t1", + region="reg", environment="dev", caller_name="cn", caller_version="cv", @@ -182,16 +186,20 @@ def test_astradbdatabaseadmin_conversions(self) -> None: ) assert adda1 == adda2 - assert adda1 != adda1._copy(id="x") + assert adda1 != adda1._copy(id="99999999-89ab-cdef-0123-456789abcdef") assert adda1 != adda1._copy(token="x") + assert adda1 != adda1._copy(region="x") assert adda1 != adda1._copy(environment="test") assert adda1 != adda1._copy(caller_name="x") assert adda1 != adda1._copy(caller_version="x") assert adda1 != adda1._copy(dev_ops_url="x") assert adda1 != adda1._copy(dev_ops_api_version="x") - assert adda1 == adda1._copy(id="x")._copy(id="i1") + assert adda1 == adda1._copy(id="99999999-89ab-cdef-0123-456789abcdef")._copy( + id="01234567-89ab-cdef-0123-456789abcdef" + ) assert adda1 == adda1._copy(token="x")._copy(token="t1") + assert adda1 == adda1._copy(region="x")._copy(region="reg") assert adda1 == adda1._copy(environment="test")._copy(environment="dev") assert adda1 == adda1._copy(caller_name="x")._copy(caller_name="cn") assert adda1 == adda1._copy(caller_version="x")._copy(caller_version="cv") @@ -200,12 +208,14 @@ def test_astradbdatabaseadmin_conversions(self) -> None: dev_ops_api_version="dvv" ) - assert adda1 != adda1.with_options(id="x") + assert adda1 != adda1.with_options(id="99999999-89ab-cdef-0123-456789abcdef") assert adda1 != adda1.with_options(token="x") assert adda1 != adda1.with_options(caller_name="x") assert adda1 != adda1.with_options(caller_version="x") - assert adda1 
== adda1.with_options(id="x").with_options(id="i1") + assert adda1 == adda1.with_options( + id="99999999-89ab-cdef-0123-456789abcdef" + ).with_options(id="01234567-89ab-cdef-0123-456789abcdef") assert adda1 == adda1.with_options(token="x").with_options(token="t1") assert adda1 == adda1.with_options(caller_name="x").with_options( caller_name="cn" @@ -269,11 +279,13 @@ def test_astradbadmin_token_inheritance(self) -> None: def test_astradbdatabaseadmin_token_inheritance(self) -> None: db_id_string = "01234567-89ab-cdef-0123-456789abcdef" adbadmin_t = AstraDBDatabaseAdmin( - db_id_string, token=StaticTokenProvider("static") + db_id_string, + token=StaticTokenProvider("static"), + region="reg", ) - adbadmin_0 = AstraDBDatabaseAdmin(db_id_string) + adbadmin_0 = AstraDBDatabaseAdmin(db_id_string, region="reg") token_f = UsernamePasswordTokenProvider(username="u", password="p") - adbadmin_f = AstraDBDatabaseAdmin(db_id_string, token=token_f) + adbadmin_f = AstraDBDatabaseAdmin(db_id_string, region="reg", token=token_f) assert adbadmin_t.get_database( token=token_f, namespace="n", region="r" @@ -339,3 +351,112 @@ def test_database_token_inheritance(self) -> None: a_database_0.get_database_admin(token=token_f) == a_database_f.get_database_admin() ) + + @pytest.mark.describe( + "test of id, endpoint, region normalization in get_database(_admin)" + ) + def test_param_normalize_getdatabase(self) -> None: + # the case of ID only is deferred to an integration test (it's impure) + api_ep = "https://01234567-89ab-cdef-0123-456789abcdef-the-region.apps.astra.datastax.com" + db_id = "01234567-89ab-cdef-0123-456789abcdef" + db_reg = "the-region" + + adm = AstraDBAdmin("t1") + + db_adm1 = adm.get_database_admin(db_id, region=db_reg) + db_adm2 = adm.get_database_admin(api_ep, region=db_reg) + db_adm3 = adm.get_database_admin(api_ep) + with pytest.raises(ValueError): + adm.get_database_admin(api_ep, region="not-that-one") + + assert db_adm1 == db_adm2 + assert db_adm2 == db_adm3 + + db_1 = adm.get_database(db_id, region=db_reg, namespace="the_ns") + db_2 = adm.get_database(api_ep, region=db_reg, namespace="the_ns") + db_3 = adm.get_database(api_ep, namespace="the_ns") + with pytest.raises(ValueError): + adm.get_database(api_ep, region="not-that-one", namespace="the_ns") + + assert db_1 == db_2 + assert db_2 == db_3 + + db_adm_m1 = AstraDBDatabaseAdmin(db_id, token="t", region=db_reg) + db_adm_m2 = AstraDBDatabaseAdmin(api_ep, token="t", region=db_reg) + db_adm_m3 = AstraDBDatabaseAdmin(api_ep, token="t") + with pytest.raises(ValueError): + AstraDBDatabaseAdmin(api_ep, token="t", region="not-that-one") + + assert db_adm_m1 == db_adm_m2 + assert db_adm_m1 == db_adm_m3 + + @pytest.mark.describe( + "test of region being deprecated in AstraDBDatabaseAdmin.get_database" + ) + def test_region_deprecation_astradbdatabaseadmin_getdatabase(self) -> None: + api_ep = "https://01234567-89ab-cdef-0123-456789abcdef-the-region.apps.astra.datastax.com" + db_adm = AstraDBDatabaseAdmin(api_ep) + with pytest.warns(DeprecationWarning): + db1 = db_adm.get_database( + region="another-region", namespace="the-namespace" + ) + # it's ignored anyway + assert db1 == db_adm.get_database(namespace="the-namespace") + + @pytest.mark.describe( + "test of spawner_database for AstraDBDatabaseAdmin if not provided" + ) + def test_spawnerdatabase_astradbdatabaseadmin_notprovided(self) -> None: + api_ep = "https://01234567-89ab-cdef-0123-456789abcdef-the-region.apps.astra.datastax.com" + db_adm = AstraDBDatabaseAdmin(api_ep) + assert 
db_adm.spawner_database.api_endpoint == api_ep + + @pytest.mark.describe( + "test of spawner_database for DataAPIDatabaseAdmin if not provided" + ) + def test_spawnerdatabase_dataapidatabaseadmin_notprovided(self) -> None: + api_ep = "http://aa" + db_adm = DataAPIDatabaseAdmin(api_ep) + assert db_adm.spawner_database.api_endpoint == api_ep + + @pytest.mark.describe( + "test of spawner_database for AstraDBDatabaseAdmin, sync db provided" + ) + def test_spawnerdatabase_astradbdatabaseadmin_syncprovided(self) -> None: + api_ep = "https://01234567-89ab-cdef-0123-456789abcdef-the-region.apps.astra.datastax.com" + db = Database(api_ep, namespace="M") + db_adm = AstraDBDatabaseAdmin(api_ep, spawner_database=db) + assert db_adm.spawner_database is db + + @pytest.mark.describe( + "test of spawner_database for AstraDBDatabaseAdmin, async db provided" + ) + def test_spawnerdatabase_astradbdatabaseadmin_asyncprovided(self) -> None: + api_ep = "https://01234567-89ab-cdef-0123-456789abcdef-the-region.apps.astra.datastax.com" + adb = AsyncDatabase(api_ep, namespace="M") + db_adm = AstraDBDatabaseAdmin(api_ep, spawner_database=adb) + assert db_adm.spawner_database is adb + + @pytest.mark.describe( + "test of spawner_database for DataAPIDatabaseAdmin, sync db provided" + ) + def test_spawnerdatabase_dataapidatabaseadmin_syncprovided(self) -> None: + api_ep = "http://aa" + db = Database(api_ep) + db_adm = DataAPIDatabaseAdmin(api_ep, spawner_database=db) + assert db_adm.spawner_database is db + + @pytest.mark.describe( + "test of spawner_database for DataAPIDatabaseAdmin, async db provided" + ) + def test_spawnerdatabase_dataapidatabaseadmin_asyncprovided(self) -> None: + api_ep = "http://aa" + adb = AsyncDatabase(api_ep) + db_adm = DataAPIDatabaseAdmin(api_ep, spawner_database=adb) + assert db_adm.spawner_database is adb + + @pytest.mark.describe("test of from_api_endpoint for AstraDBDatabaseAdmin") + def test_fromapiendpoint_astradbdatabaseadmin(self) -> None: + api_ep = "https://01234567-89ab-cdef-0123-456789abcdef-the-region.apps.astra.datastax.com" + db_adm = AstraDBDatabaseAdmin.from_api_endpoint(api_ep, token="t") + assert db_adm.get_database(namespace="M").api_endpoint == api_ep diff --git a/tests/idiomatic/unit/test_bulk_write_results.py b/tests/idiomatic/unit/test_bulk_write_results.py index cd6abc7e..519973a2 100644 --- a/tests/idiomatic/unit/test_bulk_write_results.py +++ b/tests/idiomatic/unit/test_bulk_write_results.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import pytest from astrapy.operations import reduce_bulk_write_results diff --git a/tests/idiomatic/unit/test_collection_options.py b/tests/idiomatic/unit/test_collection_options.py index 4c56315e..f6e09244 100644 --- a/tests/idiomatic/unit/test_collection_options.py +++ b/tests/idiomatic/unit/test_collection_options.py @@ -16,6 +16,8 @@ Unit tests for the validation/parsing of collection options """ +from __future__ import annotations + from typing import Any, Dict, List, Tuple import pytest diff --git a/tests/idiomatic/unit/test_collections_async.py b/tests/idiomatic/unit/test_collections_async.py index 5d61df45..362ab745 100644 --- a/tests/idiomatic/unit/test_collections_async.py +++ b/tests/idiomatic/unit/test_collections_async.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations + import pytest from astrapy import AsyncCollection, AsyncDatabase diff --git a/tests/idiomatic/unit/test_collections_sync.py b/tests/idiomatic/unit/test_collections_sync.py index 971a7239..9e1a11d1 100644 --- a/tests/idiomatic/unit/test_collections_sync.py +++ b/tests/idiomatic/unit/test_collections_sync.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import pytest from astrapy import Collection, Database diff --git a/tests/idiomatic/unit/test_databases_async.py b/tests/idiomatic/unit/test_databases_async.py index 0acd6ee0..3ef2dcbe 100644 --- a/tests/idiomatic/unit/test_databases_async.py +++ b/tests/idiomatic/unit/test_databases_async.py @@ -12,18 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import pytest -from astrapy import AsyncCollection, AsyncDatabase -from astrapy.core.defaults import DEFAULT_KEYSPACE_NAME +from astrapy import AsyncCollection, AsyncDatabase, DataAPIClient +from astrapy.constants import Environment +from astrapy.database import DEFAULT_ASTRA_DB_NAMESPACE from astrapy.exceptions import DevOpsAPIException -from ..conftest import ( - SECONDARY_NAMESPACE, - TEST_COLLECTION_INSTANCE_NAME, - DataAPICredentials, - DataAPICredentialsInfo, -) +from ..conftest import TEST_COLLECTION_INSTANCE_NAME, DataAPICredentials class TestDatabasesAsync: @@ -224,27 +222,6 @@ async def test_database_conversions_caller_mutableness_async( assert db1.to_sync().to_async() == db2 assert db1._copy() == db2 - @pytest.mark.skipif( - SECONDARY_NAMESPACE is None, reason="No secondary namespace provided" - ) - @pytest.mark.describe("test database namespace property, async") - async def test_database_namespace_async( - self, - data_api_credentials_kwargs: DataAPICredentials, - data_api_credentials_info: DataAPICredentialsInfo, - ) -> None: - db1 = AsyncDatabase( - **data_api_credentials_kwargs, - ) - assert db1.namespace == DEFAULT_KEYSPACE_NAME - - db2 = AsyncDatabase( - token=data_api_credentials_kwargs["token"], - api_endpoint=data_api_credentials_kwargs["api_endpoint"], - namespace=data_api_credentials_info["secondary_namespace"], - ) - assert db2.namespace == data_api_credentials_info["secondary_namespace"] - @pytest.mark.describe("test database id, async") async def test_database_id_async(self) -> None: db1 = AsyncDatabase( @@ -259,3 +236,49 @@ async def test_database_id_async(self) -> None: ) with pytest.raises(DevOpsAPIException): db2.id + + @pytest.mark.describe("test database default namespace per environment, async") + async def test_database_default_namespace_per_environment_async(self) -> None: + db_a_m = AsyncDatabase( + "ep", token="t", namespace="M", environment=Environment.PROD + ) + assert db_a_m.namespace == "M" + db_o_m = AsyncDatabase( + "ep", token="t", namespace="M", environment=Environment.OTHER + ) + assert db_o_m.namespace == "M" + db_a_n = AsyncDatabase("ep", token="t", environment=Environment.PROD) + assert db_a_n.namespace == DEFAULT_ASTRA_DB_NAMESPACE + db_o_n = AsyncDatabase("ep", token="t", environment=Environment.OTHER) + assert db_o_n.namespace is None + + @pytest.mark.describe( + "test database-from-client default namespace per environment, async" + ) + async def test_database_from_client_default_namespace_per_environment_async( + self, + ) -> None: + client_a = DataAPIClient(environment=Environment.PROD) + db_a_m 
= client_a.get_async_database("ep", region="r", namespace="M") + assert db_a_m.namespace == "M" + db_a_n = client_a.get_async_database("ep", region="r") + assert db_a_n.namespace == DEFAULT_ASTRA_DB_NAMESPACE + + client_o = DataAPIClient(environment=Environment.OTHER) + db_a_m = client_o.get_async_database("http://a", namespace="M") + assert db_a_m.namespace == "M" + db_a_n = client_o.get_async_database("http://a") + assert db_a_n.namespace is None + + @pytest.mark.describe( + "test database-from-dataapidbadmin default namespace per environment, async" + ) + async def test_database_from_dataapidbadmin_default_namespace_per_environment_async( + self, + ) -> None: + client = DataAPIClient(environment=Environment.OTHER) + db_admin = client.get_async_database("http://a").get_database_admin() + db_m = db_admin.get_async_database(namespace="M") + assert db_m.namespace == "M" + db_n = db_admin.get_async_database() + assert db_n.namespace is None diff --git a/tests/idiomatic/unit/test_databases_sync.py b/tests/idiomatic/unit/test_databases_sync.py index 7e36f397..f322fac9 100644 --- a/tests/idiomatic/unit/test_databases_sync.py +++ b/tests/idiomatic/unit/test_databases_sync.py @@ -12,18 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import pytest -from astrapy import Collection, Database -from astrapy.core.defaults import DEFAULT_KEYSPACE_NAME +from astrapy import Collection, DataAPIClient, Database +from astrapy.constants import Environment +from astrapy.database import DEFAULT_ASTRA_DB_NAMESPACE from astrapy.exceptions import DevOpsAPIException -from ..conftest import ( - SECONDARY_NAMESPACE, - TEST_COLLECTION_INSTANCE_NAME, - DataAPICredentials, - DataAPICredentialsInfo, -) +from ..conftest import TEST_COLLECTION_INSTANCE_NAME, DataAPICredentials class TestDatabasesSync: @@ -225,27 +223,6 @@ def test_database_conversions_caller_mutableness_sync( assert db1.to_async().to_sync() == db2 assert db1._copy() == db2 - @pytest.mark.skipif( - SECONDARY_NAMESPACE is None, reason="No secondary namespace provided" - ) - @pytest.mark.describe("test database namespace property, sync") - def test_database_namespace_sync( - self, - data_api_credentials_kwargs: DataAPICredentials, - data_api_credentials_info: DataAPICredentialsInfo, - ) -> None: - db1 = Database( - **data_api_credentials_kwargs, - ) - assert db1.namespace == DEFAULT_KEYSPACE_NAME - - db2 = Database( - token=data_api_credentials_kwargs["token"], - api_endpoint=data_api_credentials_kwargs["api_endpoint"], - namespace=data_api_credentials_info["secondary_namespace"], - ) - assert db2.namespace == data_api_credentials_info["secondary_namespace"] - @pytest.mark.describe("test database id, sync") def test_database_id_sync(self) -> None: db1 = Database( @@ -260,3 +237,43 @@ def test_database_id_sync(self) -> None: ) with pytest.raises(DevOpsAPIException): db2.id + + @pytest.mark.describe("test database default namespace per environment, sync") + def test_database_default_namespace_per_environment_sync(self) -> None: + db_a_m = Database("ep", token="t", namespace="M", environment=Environment.PROD) + assert db_a_m.namespace == "M" + db_o_m = Database("ep", token="t", namespace="M", environment=Environment.OTHER) + assert db_o_m.namespace == "M" + db_a_n = Database("ep", token="t", environment=Environment.PROD) + assert db_a_n.namespace == DEFAULT_ASTRA_DB_NAMESPACE + db_o_n = Database("ep", token="t", environment=Environment.OTHER) + assert 
db_o_n.namespace is None + + @pytest.mark.describe( + "test database-from-client default namespace per environment, sync" + ) + def test_database_from_client_default_namespace_per_environment_sync(self) -> None: + client_a = DataAPIClient(environment=Environment.PROD) + db_a_m = client_a.get_database("id", region="r", namespace="M") + assert db_a_m.namespace == "M" + db_a_n = client_a.get_database("id", region="r") + assert db_a_n.namespace == DEFAULT_ASTRA_DB_NAMESPACE + + client_o = DataAPIClient(environment=Environment.OTHER) + db_a_m = client_o.get_database("http://a", namespace="M") + assert db_a_m.namespace == "M" + db_a_n = client_o.get_database("http://a") + assert db_a_n.namespace is None + + @pytest.mark.describe( + "test database-from-dataapidbadmin default namespace per environment, sync" + ) + def test_database_from_dataapidbadmin_default_namespace_per_environment_sync( + self, + ) -> None: + client = DataAPIClient(environment=Environment.OTHER) + db_admin = client.get_database("http://a").get_database_admin() + db_m = db_admin.get_database(namespace="M") + assert db_m.namespace == "M" + db_n = db_admin.get_database() + assert db_n.namespace is None diff --git a/tests/idiomatic/unit/test_document_extractors.py b/tests/idiomatic/unit/test_document_extractors.py index 0e0d5f70..985f28e8 100644 --- a/tests/idiomatic/unit/test_document_extractors.py +++ b/tests/idiomatic/unit/test_document_extractors.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from typing import Any, Dict, List import pytest diff --git a/tests/idiomatic/unit/test_exceptions.py b/tests/idiomatic/unit/test_exceptions.py index e87f723d..bb9b7449 100644 --- a/tests/idiomatic/unit/test_exceptions.py +++ b/tests/idiomatic/unit/test_exceptions.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import pytest from astrapy.exceptions import ( diff --git a/tests/idiomatic/unit/test_ids.py b/tests/idiomatic/unit/test_ids.py index 9dd78047..57740ecc 100644 --- a/tests/idiomatic/unit/test_ids.py +++ b/tests/idiomatic/unit/test_ids.py @@ -16,6 +16,8 @@ Unit tests for the ObjectIds and UUIDn conversions, 'idiomatic' imports """ +from __future__ import annotations + import json import pytest diff --git a/tests/idiomatic/unit/test_imports.py b/tests/idiomatic/unit/test_imports.py index ac2ba8f7..3f0331c7 100644 --- a/tests/idiomatic/unit/test_imports.py +++ b/tests/idiomatic/unit/test_imports.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations + import pytest @@ -125,6 +127,12 @@ def test_imports() -> None: CollectionVectorOptions, CollectionVectorServiceOptions, DatabaseInfo, + EmbeddingProvider, + EmbeddingProviderAuthentication, + EmbeddingProviderModel, + EmbeddingProviderParameter, + EmbeddingProviderToken, + FindEmbeddingProvidersResult, ) from astrapy.operations import ( # noqa: F401 AsyncBaseOperation, diff --git a/tests/idiomatic/unit/test_info.py b/tests/idiomatic/unit/test_info.py index 76509460..8dc54eb7 100644 --- a/tests/idiomatic/unit/test_info.py +++ b/tests/idiomatic/unit/test_info.py @@ -16,6 +16,8 @@ Unit tests for the parsing of API endpoints and related """ +from __future__ import annotations + import pytest from astrapy.admin import ParsedAPIEndpoint, parse_api_endpoint diff --git a/tests/idiomatic/unit/test_timeouts.py b/tests/idiomatic/unit/test_timeouts.py index d39baf19..793e2bbc 100644 --- a/tests/idiomatic/unit/test_timeouts.py +++ b/tests/idiomatic/unit/test_timeouts.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import time import pytest diff --git a/tests/idiomatic/unit/test_token_providers.py b/tests/idiomatic/unit/test_token_providers.py index d368ce6a..8151e1b2 100644 --- a/tests/idiomatic/unit/test_token_providers.py +++ b/tests/idiomatic/unit/test_token_providers.py @@ -16,6 +16,8 @@ Unit tests for the token providers """ +from __future__ import annotations + import pytest from astrapy.authentication import ( diff --git a/tests/preprocess_env.py b/tests/preprocess_env.py index f15891fa..4dfc625f 100644 --- a/tests/preprocess_env.py +++ b/tests/preprocess_env.py @@ -18,6 +18,8 @@ Except for the vectorize information, which for the time being passes as os.environ. """ +from __future__ import annotations + import os import time from typing import Optional diff --git a/tests/vectorize_idiomatic/__init__.py b/tests/vectorize_idiomatic/__init__.py index 2c9ca172..84497ed1 100644 --- a/tests/vectorize_idiomatic/__init__.py +++ b/tests/vectorize_idiomatic/__init__.py @@ -11,3 +11,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from __future__ import annotations diff --git a/tests/vectorize_idiomatic/conftest.py b/tests/vectorize_idiomatic/conftest.py index 3bdbd124..56eab027 100644 --- a/tests/vectorize_idiomatic/conftest.py +++ b/tests/vectorize_idiomatic/conftest.py @@ -16,6 +16,8 @@ Fixtures specific to testing on vectorize-ready Data API. """ +from __future__ import annotations + import os from typing import Any, Dict, Iterable diff --git a/tests/vectorize_idiomatic/integration/__init__.py b/tests/vectorize_idiomatic/integration/__init__.py index 2c9ca172..84497ed1 100644 --- a/tests/vectorize_idiomatic/integration/__init__.py +++ b/tests/vectorize_idiomatic/integration/__init__.py @@ -11,3 +11,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ +from __future__ import annotations diff --git a/tests/vectorize_idiomatic/integration/test_vectorize_methods_async.py b/tests/vectorize_idiomatic/integration/test_vectorize_methods_async.py index e3a74848..7ebd6de2 100644 --- a/tests/vectorize_idiomatic/integration/test_vectorize_methods_async.py +++ b/tests/vectorize_idiomatic/integration/test_vectorize_methods_async.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from typing import Any, Dict, List import pytest diff --git a/tests/vectorize_idiomatic/integration/test_vectorize_methods_sync.py b/tests/vectorize_idiomatic/integration/test_vectorize_methods_sync.py index 1e2af97c..9589bc59 100644 --- a/tests/vectorize_idiomatic/integration/test_vectorize_methods_sync.py +++ b/tests/vectorize_idiomatic/integration/test_vectorize_methods_sync.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from typing import Any, Dict import pytest diff --git a/tests/vectorize_idiomatic/integration/test_vectorize_ops_async.py b/tests/vectorize_idiomatic/integration/test_vectorize_ops_async.py new file mode 100644 index 00000000..f01064e9 --- /dev/null +++ b/tests/vectorize_idiomatic/integration/test_vectorize_ops_async.py @@ -0,0 +1,48 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import pytest + +from astrapy import AsyncDatabase +from astrapy.info import EmbeddingProvider, FindEmbeddingProvidersResult + + +class TestVectorizeOpsAsync: + @pytest.mark.describe("test of find_embedding_providers, async") + async def test_collection_methods_vectorize_async( + self, + async_database: AsyncDatabase, + ) -> None: + database_admin = async_database.get_database_admin() + ep_result = database_admin.find_embedding_providers() + + assert isinstance(ep_result, FindEmbeddingProvidersResult) + + assert all( + isinstance(emb_prov, EmbeddingProvider) + for emb_prov in ep_result.embedding_providers.values() + ) + + reconstructed = { + ep_name: EmbeddingProvider.from_dict(emb_prov.as_dict()) + for ep_name, emb_prov in ep_result.embedding_providers.items() + } + assert reconstructed == ep_result.embedding_providers + dict_mapping = { + ep_name: emb_prov.as_dict() + for ep_name, emb_prov in ep_result.embedding_providers.items() + } + assert dict_mapping == ep_result.raw_info["embeddingProviders"] # type: ignore[index] diff --git a/tests/vectorize_idiomatic/integration/test_vectorize_ops_sync.py b/tests/vectorize_idiomatic/integration/test_vectorize_ops_sync.py new file mode 100644 index 00000000..e5a8d9cb --- /dev/null +++ b/tests/vectorize_idiomatic/integration/test_vectorize_ops_sync.py @@ -0,0 +1,48 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import pytest + +from astrapy import Database +from astrapy.info import EmbeddingProvider, FindEmbeddingProvidersResult + + +class TestVectorizeOpsSync: + @pytest.mark.describe("test of find_embedding_providers, sync") + def test_collection_methods_vectorize_sync( + self, + sync_database: Database, + ) -> None: + database_admin = sync_database.get_database_admin() + ep_result = database_admin.find_embedding_providers() + + assert isinstance(ep_result, FindEmbeddingProvidersResult) + + assert all( + isinstance(emb_prov, EmbeddingProvider) + for emb_prov in ep_result.embedding_providers.values() + ) + + reconstructed = { + ep_name: EmbeddingProvider.from_dict(emb_prov.as_dict()) + for ep_name, emb_prov in ep_result.embedding_providers.items() + } + assert reconstructed == ep_result.embedding_providers + dict_mapping = { + ep_name: emb_prov.as_dict() + for ep_name, emb_prov in ep_result.embedding_providers.items() + } + assert dict_mapping == ep_result.raw_info["embeddingProviders"] # type: ignore[index] diff --git a/tests/vectorize_idiomatic/integration/test_vectorize_providers.py b/tests/vectorize_idiomatic/integration/test_vectorize_providers.py index 1f883f64..eaa227f6 100644 --- a/tests/vectorize_idiomatic/integration/test_vectorize_providers.py +++ b/tests/vectorize_idiomatic/integration/test_vectorize_providers.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os from typing import Any, Dict, List, Union @@ -93,7 +95,7 @@ def test_vectorize_usage_auth_type_header_sync( # For the time being this is necessary on HEADER only embedding_api_key: Union[str, EmbeddingHeadersProvider] at_tokens = testable_vectorize_model["auth_type_tokens"] - at_token_lnames = {tk["accepted"].lower() for tk in at_tokens} + at_token_lnames = {tk.accepted.lower() for tk in at_tokens} if at_token_lnames == {"x-embedding-api-key"}: embedding_api_key = os.environ[ f"HEADER_EMBEDDING_API_KEY_{testable_vectorize_model['secret_tag']}" diff --git a/tests/vectorize_idiomatic/live_provider_info.py b/tests/vectorize_idiomatic/live_provider_info.py index f48882a6..c96df55f 100644 --- a/tests/vectorize_idiomatic/live_provider_info.py +++ b/tests/vectorize_idiomatic/live_provider_info.py @@ -12,53 +12,56 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Dict, Optional +from __future__ import annotations + +from typing import Dict from preprocess_env import ( ASTRA_DB_API_ENDPOINT, + ASTRA_DB_KEYSPACE, ASTRA_DB_TOKEN_PROVIDER, IS_ASTRA_DB, LOCAL_DATA_API_ENDPOINT, + LOCAL_DATA_API_KEYSPACE, LOCAL_DATA_API_TOKEN_PROVIDER, ) -from astrapy.api_commander import APICommander +from astrapy import DataAPIClient, Database +from astrapy.admin import parse_api_endpoint +from astrapy.constants import Environment +from astrapy.info import EmbeddingProvider -def live_provider_info() -> Dict[str, Any]: +def live_provider_info() -> Dict[str, EmbeddingProvider]: """ Query the API endpoint `findEmbeddingProviders` endpoint for the latest information. - This is where the preprocess_env variables are read to figure out whom to ask. + This utility function uses the environment variables it can find + to establish a target database to query. """ - response: Dict[str, Any] + database: Database if IS_ASTRA_DB: - if ASTRA_DB_TOKEN_PROVIDER is None: - raise ValueError("No token provider for Astra DB") - path = "api/json/v1" - headers_a: Dict[str, Optional[str]] = { - "Token": ASTRA_DB_TOKEN_PROVIDER.get_token(), - } - cmd = APICommander( - api_endpoint=ASTRA_DB_API_ENDPOINT or "", - path=path, - headers=headers_a, + parsed = parse_api_endpoint(ASTRA_DB_API_ENDPOINT) + if parsed is None: + raise ValueError( + f"Cannot parse the Astra DB API Endpoint '{ASTRA_DB_API_ENDPOINT}'" + ) + client = DataAPIClient(environment=parsed.environment) + database = client.get_database( + ASTRA_DB_API_ENDPOINT, + token=ASTRA_DB_TOKEN_PROVIDER, + namespace=ASTRA_DB_KEYSPACE, ) - response = cmd.request(payload={"findEmbeddingProviders": {}}) else: - path = "v1" - if LOCAL_DATA_API_TOKEN_PROVIDER is None: - raise ValueError("No token provider for Local Data API") - headers_l: Dict[str, Optional[str]] = { - "Token": LOCAL_DATA_API_TOKEN_PROVIDER.get_token(), - } - cmd = APICommander( - api_endpoint=LOCAL_DATA_API_ENDPOINT or "", - path=path, - headers=headers_l, + client = DataAPIClient(environment=Environment.OTHER) + database = client.get_database( + LOCAL_DATA_API_ENDPOINT, + token=LOCAL_DATA_API_TOKEN_PROVIDER, + namespace=LOCAL_DATA_API_KEYSPACE, ) - response = cmd.request(payload={"findEmbeddingProviders": {}}) - return response + database_admin = database.get_database_admin() + response = database_admin.find_embedding_providers() + return response.embedding_providers diff --git a/tests/vectorize_idiomatic/query_providers.py b/tests/vectorize_idiomatic/query_providers.py index d85c5e17..6b1f04d3 100644 --- a/tests/vectorize_idiomatic/query_providers.py +++ b/tests/vectorize_idiomatic/query_providers.py @@ -12,28 +12,32 @@ # See the License for the specific language governing permissions and # limitations under the License.
+from __future__ import annotations + import json import os import sys sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) -from typing import Any, Dict +from typing import Dict from live_provider_info import live_provider_info +from astrapy.info import EmbeddingProvider, EmbeddingProviderParameter + -def desc_param(param_data: Dict[str, Any]) -> str: - if param_data["type"].lower() == "string": +def desc_param(param_data: EmbeddingProviderParameter) -> str: + if param_data.parameter_type.lower() == "string": return "str" - elif param_data["type"].lower() == "number": - validation = param_data.get("validation", {}) + elif param_data.parameter_type.lower() == "number": + validation = param_data.validation if "numericRange" in validation: validation_nr = validation["numericRange"] assert isinstance(validation_nr, list) and len(validation_nr) == 2 range_desc = f"[{validation_nr[0]} : {validation_nr[1]}]" - if "defaultValue" in param_data: - range_desc2 = f"{range_desc} (default={param_data['defaultValue']})" + if param_data.default_value is not None: + range_desc2 = f"{range_desc} (default={param_data.default_value})" else: range_desc2 = range_desc return f"number, {range_desc2}" @@ -45,56 +49,50 @@ def desc_param(param_data: Dict[str, Any]) -> str: raise ValueError( f"Unknown number validation spec: '{json.dumps(validation)}'" ) - elif param_data["type"].lower() == "boolean": + elif param_data.parameter_type.lower() == "boolean": return "bool" else: raise NotImplementedError if __name__ == "__main__": - response: Dict[str, Any] - if "l" in sys.argv[1:]: - response = json.load(open("_providers.json")) - else: - response = live_provider_info() - json.dump(response, open("_providers.json", "w"), indent=2, sort_keys=True) + providers: Dict[str, EmbeddingProvider] = live_provider_info() + providers_json = {ep_name: ep.as_dict() for ep_name, ep in providers.items()} + json.dump(providers_json, open("_providers.json", "w"), indent=2, sort_keys=True) - provider_map = response["status"]["embeddingProviders"] - for provider, provider_data in sorted(provider_map.items()): - print(f"{provider} ({len(provider_data['models'])} models)") + for provider, provider_data in sorted(providers.items()): + print(f"{provider} ({len(provider_data.models)} models)") print(" auth:") for auth_type, auth_data in sorted( - provider_data["supportedAuthentication"].items() + provider_data.supported_authentication.items() ): - if auth_data["enabled"]: - tokens = ", ".join( - f"'{tok['accepted']}'" for tok in auth_data["tokens"] - ) + if auth_data.enabled: + tokens = ", ".join(f"'{tok.accepted}'" for tok in auth_data.tokens) print(f" {auth_type} ({tokens})") - if provider_data.get("parameters"): + if provider_data.parameters: print(" parameters") - for param_data in provider_data["parameters"]: - param_name = param_data["name"] - if param_data["required"]: + for param_data in provider_data.parameters: + param_name = param_data.name + if param_data.required: param_display_name = param_name else: param_display_name = f"({param_name})" param_desc = desc_param(param_data) print(f" - {param_display_name}: {param_desc}") print(" models:") - for model_data in sorted(provider_data["models"], key=lambda pro: pro["name"]): - model_name = model_data["name"] - if model_data["vectorDimension"] is not None: - assert model_data["vectorDimension"] > 0 - model_dim_desc = f" (D = {model_data['vectorDimension']})" + for model_data in sorted(provider_data.models, key=lambda pro: pro.name): + model_name = 
model_data.name + if model_data.vector_dimension is not None: + assert model_data.vector_dimension > 0 + model_dim_desc = f" (D = {model_data.vector_dimension})" else: model_dim_desc = "" if True: print(f" {model_name}{model_dim_desc}") - if model_data.get("parameters"): - for param_data in model_data["parameters"]: - param_name = param_data["name"] - if param_data["required"]: + if model_data.parameters: + for param_data in model_data.parameters: + param_name = param_data.name + if param_data.required: param_display_name = param_name else: param_display_name = f"({param_name})" diff --git a/tests/vectorize_idiomatic/unit/__init__.py b/tests/vectorize_idiomatic/unit/__init__.py index 2c9ca172..84497ed1 100644 --- a/tests/vectorize_idiomatic/unit/__init__.py +++ b/tests/vectorize_idiomatic/unit/__init__.py @@ -11,3 +11,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from __future__ import annotations diff --git a/tests/vectorize_idiomatic/unit/test_collectionvectorserviceoptions.py b/tests/vectorize_idiomatic/unit/test_collectionvectorserviceoptions.py index f97495a5..046d29ea 100644 --- a/tests/vectorize_idiomatic/unit/test_collectionvectorserviceoptions.py +++ b/tests/vectorize_idiomatic/unit/test_collectionvectorserviceoptions.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import pytest from astrapy.info import CollectionVectorServiceOptions diff --git a/tests/vectorize_idiomatic/unit/test_embeddingheadersprovider.py b/tests/vectorize_idiomatic/unit/test_embeddingheadersprovider.py index c650cfb5..e77e7a86 100644 --- a/tests/vectorize_idiomatic/unit/test_embeddingheadersprovider.py +++ b/tests/vectorize_idiomatic/unit/test_embeddingheadersprovider.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations + import pytest from astrapy.authentication import ( @@ -19,22 +21,22 @@ EMBEDDING_HEADER_AWS_ACCESS_ID, EMBEDDING_HEADER_AWS_SECRET_ID, AWSEmbeddingHeadersProvider, - StaticEmbeddingHeadersProvider, + EmbeddingAPIKeyHeaderProvider, coerce_embedding_headers_provider, ) class TestEmbeddingHeadersProvider: - @pytest.mark.describe("test of headers from StaticEmbeddingHeadersProvider") + @pytest.mark.describe("test of headers from EmbeddingAPIKeyHeaderProvider") def test_embeddingheadersprovider_static(self) -> None: - ehp = StaticEmbeddingHeadersProvider("x") + ehp = EmbeddingAPIKeyHeaderProvider("x") assert {k.lower(): v for k, v in ehp.get_headers().items()} == { EMBEDDING_HEADER_API_KEY.lower(): "x" } - @pytest.mark.describe("test of headers from empty StaticEmbeddingHeadersProvider") + @pytest.mark.describe("test of headers from empty EmbeddingAPIKeyHeaderProvider") def test_embeddingheadersprovider_null(self) -> None: - ehp = StaticEmbeddingHeadersProvider(None) + ehp = EmbeddingAPIKeyHeaderProvider(None) assert ehp.get_headers() == {} @pytest.mark.describe("test of headers from AWSEmbeddingHeadersProvider") @@ -53,8 +55,8 @@ def test_embeddingheadersprovider_aws(self) -> None: @pytest.mark.describe("test of embedding headers provider coercion") def test_embeddingheadersprovider_coercion(self) -> None: """This doubles as equality test.""" - ehp_s = StaticEmbeddingHeadersProvider("x") - ehp_n = StaticEmbeddingHeadersProvider(None) + ehp_s = EmbeddingAPIKeyHeaderProvider("x") + ehp_n = EmbeddingAPIKeyHeaderProvider(None) ehp_a = AWSEmbeddingHeadersProvider( embedding_access_id="x", embedding_secret_id="y", diff --git a/tests/vectorize_idiomatic/vectorize_models.py b/tests/vectorize_idiomatic/vectorize_models.py index f251da74..6a55fe1e 100644 --- a/tests/vectorize_idiomatic/vectorize_models.py +++ b/tests/vectorize_idiomatic/vectorize_models.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations + import os import sys from typing import Any, Dict, Iterable, List, Tuple @@ -23,7 +25,7 @@ EMBEDDING_HEADER_AWS_ACCESS_ID, EMBEDDING_HEADER_AWS_SECRET_ID, ) -from astrapy.info import CollectionVectorServiceOptions +from astrapy.info import CollectionVectorServiceOptions, EmbeddingProviderParameter from .live_provider_info import live_provider_info @@ -172,14 +174,14 @@ def live_test_models() -> Iterable[Dict[str, Any]]: - def _from_validation(pspec: Dict[str, Any]) -> int: - assert pspec["type"] == "number" - if "numericRange" in pspec["validation"]: - m0: int = pspec["validation"]["numericRange"][0] - m1: int = pspec["validation"]["numericRange"][1] + def _from_validation(pspec: EmbeddingProviderParameter) -> int: + assert pspec.parameter_type == "number" + if "numericRange" in pspec.validation: + m0: int = pspec.validation["numericRange"][0] + m1: int = pspec.validation["numericRange"][1] return (m0 + m1) // 2 - elif "options" in pspec["validation"]: - options: List[int] = pspec["validation"]["options"] + elif "options" in pspec.validation: + options: List[int] = pspec.validation["options"] if len(options) > 1: return options[1] else: @@ -194,21 +196,19 @@ def _collapse(longt: str) -> str: return f"{longt[:30]}_{longt[-5:]}" # generate the full list of models based on the live provider endpoint - live_info = live_provider_info()["status"]["embeddingProviders"] - for provider_name, provider_desc in sorted(live_info.items()): - for model in provider_desc["models"]: + providers = live_provider_info() + for provider_name, provider_desc in sorted(providers.items()): + for model in provider_desc.models: for auth_type_name, auth_type_desc in sorted( - provider_desc["supportedAuthentication"].items() + provider_desc.supported_authentication.items() ): - if auth_type_desc["enabled"]: + if auth_type_desc.enabled: # test assumptions on auth type if auth_type_name == "NONE": - assert auth_type_desc["tokens"] == [] + assert auth_type_desc.tokens == [] elif auth_type_name == "HEADER": header_names_lower = tuple( - sorted( - t["accepted"].lower() for t in auth_type_desc["tokens"] - ) + sorted(t.accepted.lower() for t in auth_type_desc.tokens) ) assert header_names_lower in { (EMBEDDING_HEADER_API_KEY.lower(),), @@ -219,7 +219,7 @@ def _collapse(longt: str) -> str: } elif auth_type_name == "SHARED_SECRET": authkey_names = tuple( - sorted(t["accepted"] for t in auth_type_desc["tokens"]) + sorted(t.accepted for t in auth_type_desc.tokens) ) assert authkey_names in { ("providerKey",), @@ -229,85 +229,79 @@ def _collapse(longt: str) -> str: raise ValueError("Unknown auth type") # params - collated_params = provider_desc.get("parameters", []) + model.get( - "parameters", [] - ) + collated_params = provider_desc.parameters + model.parameters all_nond_params = [ param for param in collated_params - if param["name"] != "vectorDimension" + if param.name != "vectorDimension" ] required_nond_params = { - param["name"] for param in all_nond_params if param["required"] + param.name for param in all_nond_params if param.required } optional_nond_params = { - param["name"] - for param in all_nond_params - if not param["required"] + param.name for param in all_nond_params if not param.required } # d_params = [ param for param in collated_params - if param["name"] == "vectorDimension" + if param.name == "vectorDimension" ] if d_params: d_param = d_params[0] - if "defaultValue" in d_param: - if (provider_name, model["name"]) in FORCE_DIMENSION_MAP: + if d_param.default_value is not 
None: + if (provider_name, model.name) in FORCE_DIMENSION_MAP: optional_dimension = False dimension = FORCE_DIMENSION_MAP[ - (provider_name, model["name"]) + (provider_name, model.name) ] else: optional_dimension = True - assert model["vectorDimension"] is None + assert model.vector_dimension is None dimension = _from_validation(d_param) else: optional_dimension = False - assert model["vectorDimension"] is None + assert model.vector_dimension is None dimension = _from_validation(d_param) else: optional_dimension = False - assert model["vectorDimension"] is not None - assert model["vectorDimension"] > 0 - dimension = model["vectorDimension"] + assert model.vector_dimension is not None + assert model.vector_dimension > 0 + dimension = model.vector_dimension model_parameters = { param_name: PARAMETER_VALUE_MAP[ - (provider_name, model["name"], param_name) + (provider_name, model.name, param_name) ] for param_name in required_nond_params } optional_model_parameters = { param_name: PARAMETER_VALUE_MAP[ - (provider_name, model["name"], param_name) + (provider_name, model.name, param_name) ] for param_name in optional_nond_params } if optional_dimension or optional_nond_params != set(): # we issue a minimal-params version - model_tag_0 = ( - f"{provider_name}/{model['name']}/{auth_type_name}/0" - ) + model_tag_0 = f"{provider_name}/{model.name}/{auth_type_name}/0" this_minimal_model = { "model_tag": model_tag_0, "simple_tag": _collapse( "".join(c for c in model_tag_0 if c in alphanum) ), "auth_type_name": auth_type_name, - "auth_type_tokens": auth_type_desc["tokens"], + "auth_type_tokens": auth_type_desc.tokens, "secret_tag": SECRET_NAME_ROOT_MAP[provider_name], "test_assets": TEST_ASSETS_MAP.get( - (provider_name, model["name"]), DEFAULT_TEST_ASSETS + (provider_name, model.name), DEFAULT_TEST_ASSETS ), "use_insert_one": USE_INSERT_ONE_MAP.get( - (provider_name, model["name"]), False + (provider_name, model.name), False ), "service_options": CollectionVectorServiceOptions( provider=provider_name, - model_name=model["name"], + model_name=model.name, parameters=model_parameters, ), } @@ -321,20 +315,18 @@ def _collapse(longt: str) -> str: ): root_model = { "auth_type_name": auth_type_name, - "auth_type_tokens": auth_type_desc["tokens"], + "auth_type_tokens": auth_type_desc.tokens, "dimension": dimension, "secret_tag": SECRET_NAME_ROOT_MAP[provider_name], "test_assets": TEST_ASSETS_MAP.get( - (provider_name, model["name"]), DEFAULT_TEST_ASSETS + (provider_name, model.name), DEFAULT_TEST_ASSETS ), "use_insert_one": USE_INSERT_ONE_MAP.get( - (provider_name, model["name"]), False + (provider_name, model.name), False ), } - model_tag_f = ( - f"{provider_name}/{model['name']}/{auth_type_name}/f" - ) + model_tag_f = f"{provider_name}/{model.name}/{auth_type_name}/f" this_model = { "model_tag": model_tag_f, @@ -343,7 +335,7 @@ def _collapse(longt: str) -> str: ), "service_options": CollectionVectorServiceOptions( provider=provider_name, - model_name=model["name"], + model_name=model.name, parameters={ **model_parameters, **optional_model_parameters,