Skip to content

Commit

Permalink
WIP: feat: invalidate cache on bad connections
Browse files Browse the repository at this point in the history
  • Loading branch information
rhatgadkar-goog committed Nov 12, 2024
1 parent 22ef77c commit 01cc3aa
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 4 deletions.
21 changes: 19 additions & 2 deletions google/cloud/alloydb/connector/async_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,14 @@ async def connect(
# if ip_type is str, convert to IPTypes enum
if isinstance(ip_type, str):
ip_type = IPTypes(ip_type.upper())
conn_info = await cache.connect_info()
ip_address = conn_info.get_preferred_ip(ip_type)
try:
conn_info = await cache.connect_info()
ip_address = conn_info.get_preferred_ip(ip_type)
except Exception:
# with an error from AlloyDB Admin API call or IP type, invalidate
# the cache and re-raise the error
await self._remove_cached(instance_uri)
raise
logger.debug(f"['{instance_uri}']: Connecting to {ip_address}:5433")

# callable to be used for auto IAM authn
Expand All @@ -202,6 +208,17 @@ def get_authentication_token() -> str:
await cache.force_refresh()
raise

async def _remove_cached(self, instance_uri: str) -> None:
"""Stops all background refreshes and deletes the connection
info cache from the map of caches.
"""
logger.debug(
f"['{instance_uri}']: Removing connection info from cache"
)
# remove cache from stored caches and close it
cache = self._cache.pop(instance_uri)
await cache.close()

async def __aenter__(self) -> Any:
"""Enter async context manager by returning Connector object"""
return self
Expand Down
21 changes: 19 additions & 2 deletions google/cloud/alloydb/connector/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,14 @@ async def connect_async(self, instance_uri: str, driver: str, **kwargs: Any) ->
# if ip_type is str, convert to IPTypes enum
if isinstance(ip_type, str):
ip_type = IPTypes(ip_type.upper())
conn_info = await cache.connect_info()
ip_address = conn_info.get_preferred_ip(ip_type)
try:
conn_info = await cache.connect_info()
ip_address = conn_info.get_preferred_ip(ip_type)
except Exception:
# with an error from AlloyDB Admin API call or IP type, invalidate
# the cache and re-raise the error
await self._remove_cached(instance_uri)
raise
logger.debug(f"['{instance_uri}']: Connecting to {ip_address}:5433")

# synchronous drivers are blocking and run using executor
Expand Down Expand Up @@ -334,6 +340,17 @@ def metadata_exchange(

return sock

async def _remove_cached(self, instance_uri: str) -> None:
"""Stops all background refreshes and deletes the connection
info cache from the map of caches.
"""
logger.debug(
f"['{instance_uri}']: Removing connection info from cache"
)
# remove cache from stored caches and close it
cache = self._cache.pop(instance_uri)
await cache.close()

def __enter__(self) -> "Connector":
"""Enter context manager by returning Connector object"""
return self
Expand Down
11 changes: 11 additions & 0 deletions tests/unit/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import asyncio
from aiohttp import ClientResponseError
from datetime import datetime
from datetime import timedelta
from datetime import timezone
Expand Down Expand Up @@ -206,7 +207,14 @@ def __init__(
self._user_agent = f"test-user-agent+{driver}"
self._credentials = FakeCredentials()

i = FakeInstance()
# The instances that currently exist and the client can send API requests to.
self.existing_instances = [f"projects/{i.project}/locations/{i.region}/clusters/{i.cluster}/instances/{i.name}"]

async def _get_metadata(self, *args: Any, **kwargs: Any) -> str:
instance_uri = f"projects/{self.instance.project}/locations/{self.instance.region}/clusters/{self.instance.cluster}/instances/{self.instance.name}"
if instance_uri not in self.existing_instances:
raise ClientResponseError(None, 404)
return self.instance.ip_addrs

async def _get_client_certificate(
Expand All @@ -216,6 +224,9 @@ async def _get_client_certificate(
cluster: str,
pub_key: str,
) -> Tuple[str, List[str]]:
instance_uri = f"projects/{self.instance.project}/locations/{self.instance.region}/clusters/{self.instance.cluster}/instances/{self.instance.name}"
if instance_uri not in self.existing_instances:
raise ClientResponseError(None, 404)
root_cert, intermediate_cert, server_cert = self.instance.get_pem_certs()
# encode public key to bytes
pub_key_bytes: rsa.RSAPublicKey = serialization.load_pem_public_key(
Expand Down
54 changes: 54 additions & 0 deletions tests/unit/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,13 @@
from mock import patch
from mocks import FakeAlloyDBClient
from mocks import FakeCredentials
from mocks import FakeInstance
from aiohttp import ClientResponseError
import pytest

from google.cloud.alloydb.connector import Connector
from google.cloud.alloydb.connector import IPTypes
from google.cloud.alloydb.connector.instance import RefreshAheadCache


def test_Connector_init(credentials: FakeCredentials) -> None:
Expand Down Expand Up @@ -203,3 +206,54 @@ def test_Connector_close_called_multiple_times(credentials: FakeCredentials) ->
assert connector._thread.is_alive() is False
# call connector.close a second time
connector.close()


@pytest.mark.asyncio
async def test_Connector_remove_cached_bad_instance(
credentials: FakeCredentials, fake_client: FakeAlloyDBClient
) -> None:
"""When a Connector attempts to retrieve connection info for a
non-existent instance, it should delete the instance from
the cache and ensure no background refresh happens (which would be
wasted cycles).
"""
instance_uri = "projects/test-project/locations/test-region/clusters/test-cluster/instances/bad-test-instance"
with Connector(credentials) as connector:
connector._client = FakeAlloyDBClient(instance = FakeInstance(name = "bad-test-instance"))
# patch db connection creation
with patch("google.cloud.alloydb.connector.pg8000.connect") as mock_connect:
mock_connect.return_value = True
cache = RefreshAheadCache(instance_uri, fake_client, connector._keys)
connector._cache[instance_uri] = cache
with pytest.raises(ClientResponseError):
await connector.connect_async(instance_uri, "pg8000")
assert instance_uri not in connector._cache


# def test_Connector_remove_cached_no_ip_type(
# fake_credentials: FakeCredentials, fake_client: FakeAlloyDBClient
# ) -> None:
# """When a Connector attempts to connect and preferred IP type is not present,
# it should delete the instance from the cache and ensure no background refresh
# happens (which would be wasted cycles).
# """
# # set instance to only have public IP
# fake_client.instance.ip_addrs = {"PRIMARY": "127.0.0.1"}
# async with Connector(
# credentials=fake_credentials, loop=asyncio.get_running_loop()
# ) as connector:
# conn_name = "test-project:test-region:test-instance"
# # populate cache
# cache = RefreshAheadCache(conn_name, fake_client_sync, connector._keys)
# connector._cache[conn_name] = cache
# # test instance does not have Private IP, thus should invalidate cache
# with pytest.raises(CloudSQLIPTypeError):
# await connector.connect_async(
# conn_name,
# "pg8000",
# user="my-user",
# password="my-pass",
# ip_type="private",
# )
# # check that cache has been removed from dict
# assert conn_name not in connector._cache

0 comments on commit 01cc3aa

Please sign in to comment.