Skip to content

Commit

Permalink
feat: add private_key_passphrase keyword.
Browse files Browse the repository at this point in the history
  • Loading branch information
yassun7010 committed Jan 9, 2025
1 parent 138241c commit 75b11d6
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 3 deletions.
4 changes: 3 additions & 1 deletion src/snowflake/connector/auth/keypair.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class AuthByKeyPair(AuthByPlugin):
def __init__(
self,
private_key: bytes | RSAPrivateKey,
private_key_passphrase: bytes | None = None,
lifetime_in_seconds: int = LIFETIME,
**kwargs,
) -> None:
Expand Down Expand Up @@ -76,6 +77,7 @@ def __init__(
)

self._private_key: bytes | RSAPrivateKey | None = private_key
self._private_key_passphrase: bytes | None = private_key_passphrase
self._jwt_token = ""
self._jwt_token_exp = 0
self._lifetime = timedelta(
Expand Down Expand Up @@ -109,7 +111,7 @@ def prepare(
try:
private_key = load_der_private_key(
data=self._private_key,
password=None,
password=self._private_key_passphrase,
backend=default_backend(),
)
except Exception as e:
Expand Down
3 changes: 3 additions & 0 deletions src/snowflake/connector/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ def _get_private_bytes_from_file(
"passcode_in_password": (False, bool), # Snowflake MFA
"passcode": (None, (type(None), str)), # Snowflake MFA
"private_key": (None, (type(None), bytes, RSAPrivateKey)),
"private_key_passphrase": (None, (type(None), bytes)),
"private_key_file": (None, (type(None), str)),
"private_key_file_pwd": (None, (type(None), str, bytes)),
"token": (None, (type(None), str)), # OAuth/JWT/PAT Token
Expand Down Expand Up @@ -1069,6 +1070,7 @@ def __open_connection(self):

elif self._authenticator == KEY_PAIR_AUTHENTICATOR:
private_key = self._private_key
private_key_passphrase = self._private_key_passphrase

if self._private_key_file:
private_key = _get_private_bytes_from_file(
Expand All @@ -1078,6 +1080,7 @@ def __open_connection(self):

self.auth_class = AuthByKeyPair(
private_key=private_key,
private_key_passphrase=private_key_passphrase,
timeout=self.login_timeout,
backoff_generator=self._backoff_generator,
)
Expand Down
41 changes: 39 additions & 2 deletions test/unit/test_auth_keypair.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,39 @@ def test_auth_keypair():
assert rest.master_token == "MASTER_TOKEN"


def test_auth_keypair_with_passphrase():
"""Simple Key Pair test with passphrase."""

passphrase = b"test"
private_key_der, public_key_der_encoded = generate_key_pair(
2048,
passphrase=passphrase,
)
application = "testapplication"
account = "testaccount"
user = "testuser"
auth_instance = AuthByKeyPair(
private_key=private_key_der,
private_key_passphrase=passphrase,
)
auth_instance._retry_ctx.set_start_time()
auth_instance.handle_timeout(
authenticator="SNOWFLAKE_JWT",
service_name=None,
account=account,
user=user,
password=None,
)

# success test case
rest = _init_rest(application, _create_mock_auth_keypair_rest_response())
auth = Auth(rest)
auth.authenticate(auth_instance, account, user)
assert not rest._connection.errorhandler.called # not error
assert rest.token == "TOKEN"
assert rest.master_token == "MASTER_TOKEN"


def test_auth_keypair_abc():
"""Simple Key Pair test using abstraction layer."""
private_key_der, public_key_der_encoded = generate_key_pair(2048)
Expand Down Expand Up @@ -153,15 +186,19 @@ def _init_rest(application, post_requset):
return rest


def generate_key_pair(key_length):
def generate_key_pair(key_length: int, *, passphrase: bytes | None = None):
private_key = rsa.generate_private_key(
backend=default_backend(), public_exponent=65537, key_size=key_length
)

private_key_der = private_key.private_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
encryption_algorithm=(
serialization.BestAvailableEncryption(passphrase)
if passphrase
else serialization.NoEncryption()
),
)

public_key_pem = (
Expand Down

0 comments on commit 75b11d6

Please sign in to comment.