Skip to content

Commit

Permalink
Support custom cache for OAuth2 tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
mdesmet committed Aug 19, 2022
1 parent aee6064 commit c04d408
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 10 deletions.
32 changes: 32 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,38 @@ The OAuth2 token will be cached either per `trino.auth.OAuth2Authentication` ins
)
```

A custom caching implementation can be provided by creating a class implementing the `trino.auth.OAuth2TokenCache` abstract class and adding it as in `OAuth2Authentication(cache=my_custom_cache_impl)`. The custom caching implementation enables usage in multi-user environments (notebooks, web applications) in combination with a custom `redirect_auth_url_handler` as explained above.

```python
from typing import Optional

from trino.auth import OAuth2Authentication, OAuth2TokenCache
from trino.dbapi import connect


class MyCustomCacheImpl(OAuth2TokenCache):
def get_token_from_cache(self, host: str) -> Optional[str]:
# Retrieve your cached token from a distributed system
# and return it
pass

def store_token_to_cache(self, host: str, token: str) -> None:
# Store your cached token in a distributed system
pass


def my_custom_redirect_handler(url: str) -> None:
# ensure the url is opened by the user that should perform the authentication
pass

conn = connect(
user="<username>",
auth=OAuth2Authentication(cache=MyCustomCacheImpl(), redirect_auth_url_handler=my_custom_redirect_handler),
http_scheme="https",
...
)
```

### Certificate Authentication

`CertificateAuthentication` class can be used to connect to Trino cluster configured with [certificate based authentication](https://trino.io/docs/current/security/certificate.html). `CertificateAuthentication` requires paths to a valid client certificate and private key.
Expand Down
51 changes: 49 additions & 2 deletions tests/unit/test_dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# limitations under the License.
import threading
import uuid
from unittest.mock import patch
from unittest.mock import patch, MagicMock

import httpretty
from httpretty import httprettified
Expand All @@ -20,7 +20,7 @@
from tests.unit.oauth_test_utils import _post_statement_requests, _get_token_requests, RedirectHandler, \
GetTokenCallback, REDIRECT_RESOURCE, TOKEN_RESOURCE, PostStatementCallback, SERVER_ADDRESS
from trino import constants
from trino.auth import OAuth2Authentication
from trino.auth import OAuth2Authentication, OAuth2TokenCache
from trino.dbapi import connect


Expand Down Expand Up @@ -107,6 +107,53 @@ def test_token_retrieved_once_per_auth_instance(sample_post_response_data):
assert len(_get_token_requests(challenge_id)) == 2


@httprettified
def test_custom_token_cache_is_invoked(sample_post_response_data):
host = "coordinator"
token = str(uuid.uuid4())
challenge_id = str(uuid.uuid4())

redirect_server = f"{REDIRECT_RESOURCE}/{challenge_id}"
token_server = f"{TOKEN_RESOURCE}/{challenge_id}"

post_statement_callback = PostStatementCallback(redirect_server, token_server, [token], sample_post_response_data)

# bind post statement
httpretty.register_uri(
method=httpretty.POST,
uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}",
body=post_statement_callback)

# bind get token
get_token_callback = GetTokenCallback(token_server, token)
httpretty.register_uri(
method=httpretty.GET,
uri=token_server,
body=get_token_callback)

redirect_handler = RedirectHandler()

custom_cache = MagicMock(OAuth2TokenCache)
custom_cache.get_token_from_cache = MagicMock(side_effect=[None, token, token, token])
custom_cache.store_token_to_cache = MagicMock()

with connect(
host,
user="test",
auth=OAuth2Authentication(redirect_auth_url_handler=redirect_handler, cache=custom_cache),
http_scheme=constants.HTTPS
) as conn:
conn.cursor().execute("SELECT 1")
conn.cursor().execute("SELECT 2")
conn.cursor().execute("SELECT 3")

assert len(_get_token_requests(challenge_id)) == 1
custom_cache.get_token_from_cache.assert_called_with(host)
assert custom_cache.get_token_from_cache.call_count == 4
custom_cache.store_token_to_cache.assert_called_with(host, token)
assert custom_cache.store_token_to_cache.call_count == 1


@httprettified
def test_token_retrieved_once_when_authentication_instance_is_shared(sample_post_response_data):
token = str(uuid.uuid4())
Expand Down
26 changes: 18 additions & 8 deletions trino/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def __call__(self, url: str):
handler(url)


class _OAuth2TokenCache(metaclass=abc.ABCMeta):
class OAuth2TokenCache(metaclass=abc.ABCMeta):
"""
Abstract class for OAuth token cache, inherit from this class to implement your own token cache.
"""
Expand All @@ -216,7 +216,7 @@ def store_token_to_cache(self, host: str, token: str) -> None:
pass


class _OAuth2TokenInMemoryCache(_OAuth2TokenCache):
class _OAuth2TokenInMemoryCache(OAuth2TokenCache):
"""
In-memory token cache implementation. The token is stored per host, so multiple clients can share the same cache.
"""
Expand All @@ -231,7 +231,7 @@ def store_token_to_cache(self, host: str, token: str) -> None:
self._cache[host] = token


class _OAuth2KeyRingTokenCache(_OAuth2TokenCache):
class _OAuth2KeyRingTokenCache(OAuth2TokenCache):
"""
Keyring Token Cache implementation
"""
Expand Down Expand Up @@ -272,10 +272,9 @@ class _OAuth2TokenBearer(AuthBase):
MAX_OAUTH_ATTEMPTS = 5
_BEARER_PREFIX = re.compile(r"bearer", flags=re.IGNORECASE)

def __init__(self, redirect_auth_url_handler: Callable[[str], None]):
def __init__(self, redirect_auth_url_handler: Callable[[str], None], custom_cache: Optional[OAuth2TokenCache]):
self._redirect_auth_url = redirect_auth_url_handler
keyring_cache = _OAuth2KeyRingTokenCache()
self._token_cache = keyring_cache if keyring_cache.is_keyring_available() else _OAuth2TokenInMemoryCache()
self._token_cache = self._setup_cache(custom_cache)
self._token_lock = threading.Lock()
self._inside_oauth_attempt_lock = threading.Lock()
self._inside_oauth_attempt_blocker = threading.Event()
Expand All @@ -291,6 +290,17 @@ def __call__(self, r):

return r

def _setup_cache(self, custom_cache):
if custom_cache is not None:
if not isinstance(custom_cache, OAuth2TokenCache):
raise exceptions.TrinoAuthError("Custom cache does not implement `trino.auth.OAuth2TokenCache` "
"interface")
return custom_cache
keyring_cache = _OAuth2KeyRingTokenCache()
if keyring_cache.is_keyring_available():
return keyring_cache
return _OAuth2TokenInMemoryCache()

def _authenticate(self, response, **kwargs):
if not 400 <= response.status_code < 500:
return response
Expand Down Expand Up @@ -396,9 +406,9 @@ class OAuth2Authentication(Authentication):
def __init__(self, redirect_auth_url_handler=CompositeRedirectHandler([
WebBrowserRedirectHandler(),
ConsoleRedirectHandler()
])):
]), cache: Optional[OAuth2TokenCache] = None):
self._redirect_auth_url = redirect_auth_url_handler
self._bearer = _OAuth2TokenBearer(self._redirect_auth_url)
self._bearer = _OAuth2TokenBearer(self._redirect_auth_url, custom_cache=cache)

def set_http_session(self, http_session):
http_session.auth = self._bearer
Expand Down

0 comments on commit c04d408

Please sign in to comment.