-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add support for asyncpg (#199)
Co-authored-by: Jack Wotherspoon <[email protected]>
- Loading branch information
1 parent
3dd1b05
commit d14617b
Showing
8 changed files
with
610 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
# Copyright 2024 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from __future__ import annotations | ||
|
||
import asyncio | ||
from types import TracebackType | ||
from typing import Any, Dict, Optional, Type, TYPE_CHECKING | ||
|
||
from google.auth import default | ||
from google.auth.credentials import with_scopes_if_required | ||
|
||
import google.cloud.alloydb.connector.asyncpg as asyncpg | ||
from google.cloud.alloydb.connector.client import AlloyDBClient | ||
from google.cloud.alloydb.connector.instance import Instance | ||
from google.cloud.alloydb.connector.utils import generate_keys | ||
|
||
if TYPE_CHECKING: | ||
from google.auth.credentials import Credentials | ||
|
||
|
||
class AsyncConnector: | ||
"""A class to configure and create connections to Cloud SQL instances | ||
asynchronously. | ||
Args: | ||
credentials (google.auth.credentials.Credentials): | ||
A credentials object created from the google-auth Python library. | ||
If not specified, Application Default Credentials are used. | ||
quota_project (str): The Project ID for an existing Google Cloud | ||
project. The project specified is used for quota and | ||
billing purposes. | ||
Defaults to None, picking up project from environment. | ||
alloydb_api_endpoint (str): Base URL to use when calling | ||
the AlloyDB API endpoint. Defaults to "https://alloydb.googleapis.com". | ||
""" | ||
|
||
def __init__( | ||
self, | ||
credentials: Optional[Credentials] = None, | ||
quota_project: Optional[str] = None, | ||
alloydb_api_endpoint: str = "https://alloydb.googleapis.com", | ||
) -> None: | ||
self._instances: Dict[str, Instance] = {} | ||
# initialize default params | ||
self._quota_project = quota_project | ||
self._alloydb_api_endpoint = alloydb_api_endpoint | ||
# initialize credentials | ||
scopes = ["https://www.googleapis.com/auth/cloud-platform"] | ||
if credentials: | ||
self._credentials = with_scopes_if_required(credentials, scopes=scopes) | ||
# otherwise use application default credentials | ||
else: | ||
self._credentials, _ = default(scopes=scopes) | ||
self._keys = asyncio.create_task(generate_keys()) | ||
self._client: Optional[AlloyDBClient] = None | ||
|
||
async def connect( | ||
self, | ||
instance_uri: str, | ||
driver: str, | ||
**kwargs: Any, | ||
) -> Any: | ||
""" | ||
Asynchronously prepares and returns a database connection object. | ||
Starts tasks to refresh the certificates and get | ||
AlloyDB instance IP address. Creates a secure TLS connection | ||
to establish connection to AlloyDB instance. | ||
Args: | ||
instance_uri (str): The instance URI of the AlloyDB instance. | ||
ex. projects/<PROJECT>/locations/<REGION>/clusters/<CLUSTER>/instances/<INSTANCE> | ||
driver (str): A string representing the database driver to connect | ||
with. Supported drivers are asyncpg. | ||
**kwargs: Pass in any database driver-specific arguments needed | ||
to fine tune connection. | ||
Returns: | ||
connection: A DBAPI connection to the specified AlloyDB instance. | ||
""" | ||
if self._client is None: | ||
# lazy init client as it has to be initialized in async context | ||
self._client = AlloyDBClient( | ||
self._alloydb_api_endpoint, | ||
self._quota_project, | ||
self._credentials, | ||
) | ||
|
||
# use existing connection info if possible | ||
if instance_uri in self._instances: | ||
instance = self._instances[instance_uri] | ||
else: | ||
instance = Instance(instance_uri, self._client, self._keys) | ||
self._instances[instance_uri] = instance | ||
|
||
connect_func = { | ||
"asyncpg": asyncpg.connect, | ||
} | ||
# only accept supported database drivers | ||
try: | ||
connector = connect_func[driver] | ||
except KeyError: | ||
raise ValueError(f"Driver '{driver}' is not a supported database driver.") | ||
|
||
# Host and ssl options come from the certificates and instance IP | ||
# address so we don't want the user to specify them. | ||
kwargs.pop("host", None) | ||
kwargs.pop("ssl", None) | ||
kwargs.pop("port", None) | ||
|
||
# get connection info for AlloyDB instance | ||
ip_address, context = await instance.connection_info() | ||
|
||
try: | ||
return await connector(ip_address, context, **kwargs) | ||
except Exception: | ||
# we attempt a force refresh, then throw the error | ||
await instance.force_refresh() | ||
raise | ||
|
||
async def __aenter__(self) -> Any: | ||
"""Enter async context manager by returning Connector object""" | ||
return self | ||
|
||
async def __aexit__( | ||
self, | ||
exc_type: Optional[Type[BaseException]], | ||
exc_val: Optional[BaseException], | ||
exc_tb: Optional[TracebackType], | ||
) -> None: | ||
"""Exit async context manager by closing Connector""" | ||
await self.close() | ||
|
||
async def close(self) -> None: | ||
"""Helper function to cancel Instances' tasks | ||
and close client.""" | ||
await asyncio.gather( | ||
*[instance.close() for instance in self._instances.values()] | ||
) | ||
if self._client: | ||
await self._client.close() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
# Copyright 2024 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import ssl | ||
from typing import Any, TYPE_CHECKING | ||
|
||
SERVER_PROXY_PORT = 5433 | ||
|
||
if TYPE_CHECKING: | ||
import asyncpg | ||
|
||
|
||
async def connect( | ||
ip_address: str, ctx: ssl.SSLContext, **kwargs: Any | ||
) -> "asyncpg.Connection": | ||
"""Helper function to create an asyncpg DB-API connection object. | ||
:type ip_address: str | ||
:param ip_address: A string containing an IP address for the AlloyDB | ||
instance. | ||
:type ctx: ssl.SSLContext | ||
:param ctx: An SSLContext object created from the AlloyDB server CA | ||
cert and ephemeral cert. | ||
:type kwargs: Any | ||
:param kwargs: Keyword arguments for establishing asyncpg connection | ||
object to AlloyDB instance. | ||
:rtype: asyncpg.Connection | ||
:returns: An asyncpg.Connection object to an AlloyDB instance. | ||
""" | ||
try: | ||
import asyncpg | ||
except ImportError: | ||
raise ImportError( | ||
'Unable to import module "asyncpg." Please install and try again.' | ||
) | ||
user = kwargs.pop("user") | ||
db = kwargs.pop("db") | ||
passwd = kwargs.pop("password") | ||
|
||
return await asyncpg.connect( | ||
user=user, | ||
database=db, | ||
password=passwd, | ||
host=ip_address, | ||
port=SERVER_PROXY_PORT, | ||
ssl=ctx, | ||
direct_tls=True, | ||
**kwargs, | ||
) |
Oops, something went wrong.