Skip to content

Commit

Permalink
Add support for specifying 'RSAPublicKey' instance instead of raw byt…
Browse files Browse the repository at this point in the history
…es (#1477)

* Add support for specifying 'RSAPublicKey' instance instead of raw bytes

This can be used to externalize the JWT encoding process.

* Add test for private key abstraction layer

* Add 'isinstance' check to make sure private key has an expected type

* Revert method signature change

Note that while this method does not require a private key, the change is
inconsequential because we're anyway expecting something that implements
a private key at the class level (either bytes or an abstract implementation)

* Be more specific in type error message

* Add failing test for non-bytes, non-RSAPrivateKey value

* Fix linting issues

* add changelog

* Move cases which are now handled by type testing over to unit test

---------

Co-authored-by: sfc-gh-sfan <[email protected]>
  • Loading branch information
malthe and sfc-gh-sfan authored Aug 4, 2023
1 parent 1380e41 commit db8e265
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 24 deletions.
4 changes: 4 additions & 0 deletions DESCRIPTION.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-conne

# Release Notes

- v3.1.1(TBD)

- Support `RSAPublicKey` when constructing `AuthByKeyPair` in addition to raw bytes.

- v3.1.0(July 31,2023)

- Added a feature that lets you add connection definitions to the `connections.toml` configuration file. A connection definition refers to a collection of connection parameters, for example, if you wanted to define a connection named `prod``:
Expand Down
50 changes: 29 additions & 21 deletions src/snowflake/connector/auth/keypair.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,17 +47,18 @@ class AuthByKeyPair(AuthByPlugin):

def __init__(
self,
private_key: bytes,
private_key: bytes | RSAPrivateKey,
lifetime_in_seconds: int = LIFETIME,
) -> None:
"""Inits AuthByKeyPair class with private key.
Args:
private_key: a byte array of der formats of private key
private_key: a byte array of der formats of private key, or an
object that implements the `RSAPrivateKey` interface.
lifetime_in_seconds: number of seconds the JWT token will be valid
"""
super().__init__()
self._private_key: bytes | None = private_key
self._private_key: bytes | RSAPrivateKey | None = private_key
self._jwt_token = ""
self._jwt_token_exp = 0
self._lifetime = timedelta(
Expand Down Expand Up @@ -102,25 +103,32 @@ def prepare(

now = datetime.utcnow()

try:
private_key = load_der_private_key(
data=self._private_key,
password=None,
backend=default_backend(),
)
except Exception as e:
raise ProgrammingError(
msg=f"Failed to load private key: {e}\nPlease provide a valid "
"unencrypted rsa private key in DER format as bytes object",
errno=ER_INVALID_PRIVATE_KEY,
)
if isinstance(self._private_key, bytes):
try:
private_key = load_der_private_key(
data=self._private_key,
password=None,
backend=default_backend(),
)
except Exception as e:
raise ProgrammingError(
msg=f"Failed to load private key: {e}\nPlease provide a valid "
"unencrypted rsa private key in DER format as bytes object",
errno=ER_INVALID_PRIVATE_KEY,
)

if not isinstance(private_key, RSAPrivateKey):
raise ProgrammingError(
msg=f"Private key type ({private_key.__class__.__name__}) not supported."
"\nPlease provide a valid rsa private key in DER format as bytes "
"object",
errno=ER_INVALID_PRIVATE_KEY,
if not isinstance(private_key, RSAPrivateKey):
raise ProgrammingError(
msg=f"Private key type ({private_key.__class__.__name__}) not supported."
"\nPlease provide a valid rsa private key in DER format as bytes "
"object",
errno=ER_INVALID_PRIVATE_KEY,
)
elif isinstance(self._private_key, RSAPrivateKey):
private_key = self._private_key
else:
raise TypeError(
f"Expected bytes or RSAPrivateKey, got {type(self._private_key)}"
)

public_key_fp = self.calculate_public_key_fingerprint(private_key)
Expand Down
4 changes: 3 additions & 1 deletion src/snowflake/connector/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
from typing import Any, Callable, Generator, Iterable, NamedTuple, Sequence
from uuid import UUID

from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey

from . import errors, proxy
from ._query_context_cache import QueryContextCache
from .auth import (
Expand Down Expand Up @@ -146,7 +148,7 @@ def DefaultConverterClass() -> type:
), # network timeout (infinite by default)
"passcode_in_password": (False, bool), # Snowflake MFA
"passcode": (None, (type(None), str)), # Snowflake MFA
"private_key": (None, (type(None), str)),
"private_key": (None, (type(None), str, RSAPrivateKey)),
"token": (None, (type(None), str)), # OAuth or JWT Token
"authenticator": (DEFAULT_AUTHENTICATOR, (type(None), str)),
"mfa_callback": (None, (type(None), Callable)),
Expand Down
2 changes: 0 additions & 2 deletions test/integ/test_key_pair_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,6 @@ def test_bad_private_key(db_parameters):
)

bad_private_key_test_cases = [
"abcd",
1234,
b"abcd",
dsa_private_key_der,
encrypted_rsa_private_key_der,
Expand Down
57 changes: 57 additions & 0 deletions test/unit/test_auth_keypair.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
from cryptography.hazmat.primitives.serialization import load_der_private_key
from pytest import raises

from snowflake.connector.auth import Auth
from snowflake.connector.constants import OCSPMode
Expand Down Expand Up @@ -59,6 +62,60 @@ def test_auth_keypair():
assert rest.master_token == "MASTER_TOKEN"


def test_auth_keypair_abc():
"""Simple Key Pair test using abstraction layer."""
private_key_der, public_key_der_encoded = generate_key_pair(2048)
application = "testapplication"
account = "testaccount"
user = "testuser"

private_key = load_der_private_key(
data=private_key_der,
password=None,
backend=default_backend(),
)

assert isinstance(private_key, RSAPrivateKey)

auth_instance = AuthByKeyPair(private_key=private_key)
auth_instance.handle_timeout(
authenticator="SNOWFLAKE_JWT",
service_name=None,
account=account,
user=user,
password=None,
)

# success test case
rest = _init_rest(application, _create_mock_auth_keypair_rest_response())
auth = Auth(rest)
auth.authenticate(auth_instance, account, user)
assert not rest._connection.errorhandler.called # not error
assert rest.token == "TOKEN"
assert rest.master_token == "MASTER_TOKEN"


def test_auth_keypair_bad_type():
"""Simple Key Pair test using abstraction layer."""
account = "testaccount"
user = "testuser"

class Bad:
pass

for bad_private_key in ("abcd", 1234, Bad()):
auth_instance = AuthByKeyPair(private_key=bad_private_key)
with raises(TypeError) as ex:
auth_instance.handle_timeout(
authenticator="SNOWFLAKE_JWT",
service_name=None,
account=account,
user=user,
password=None,
)
assert str(type(bad_private_key)) in str(ex)


def _init_rest(application, post_requset):
connection = MagicMock()
connection._login_timeout = 120
Expand Down

0 comments on commit db8e265

Please sign in to comment.