From 75b11d64e2b89859c090773485a2624eb309fd22 Mon Sep 17 00:00:00 2001 From: yassun7010 Date: Thu, 9 Jan 2025 20:03:07 +0900 Subject: [PATCH] feat: add private_key_passphrase keyword. --- src/snowflake/connector/auth/keypair.py | 4 ++- src/snowflake/connector/connection.py | 3 ++ test/unit/test_auth_keypair.py | 41 +++++++++++++++++++++++-- 3 files changed, 45 insertions(+), 3 deletions(-) diff --git a/src/snowflake/connector/auth/keypair.py b/src/snowflake/connector/auth/keypair.py index a5d658666..f94f339e7 100644 --- a/src/snowflake/connector/auth/keypair.py +++ b/src/snowflake/connector/auth/keypair.py @@ -44,6 +44,7 @@ class AuthByKeyPair(AuthByPlugin): def __init__( self, private_key: bytes | RSAPrivateKey, + private_key_passphrase: bytes | None = None, lifetime_in_seconds: int = LIFETIME, **kwargs, ) -> None: @@ -76,6 +77,7 @@ def __init__( ) self._private_key: bytes | RSAPrivateKey | None = private_key + self._private_key_passphrase: bytes | None = private_key_passphrase self._jwt_token = "" self._jwt_token_exp = 0 self._lifetime = timedelta( @@ -109,7 +111,7 @@ def prepare( try: private_key = load_der_private_key( data=self._private_key, - password=None, + password=self._private_key_passphrase, backend=default_backend(), ) except Exception as e: diff --git a/src/snowflake/connector/connection.py b/src/snowflake/connector/connection.py index c51c33c60..6927ebb42 100644 --- a/src/snowflake/connector/connection.py +++ b/src/snowflake/connector/connection.py @@ -185,6 +185,7 @@ def _get_private_bytes_from_file( "passcode_in_password": (False, bool), # Snowflake MFA "passcode": (None, (type(None), str)), # Snowflake MFA "private_key": (None, (type(None), bytes, RSAPrivateKey)), + "private_key_passphrase": (None, (type(None), bytes)), "private_key_file": (None, (type(None), str)), "private_key_file_pwd": (None, (type(None), str, bytes)), "token": (None, (type(None), str)), # OAuth/JWT/PAT Token @@ -1069,6 +1070,7 @@ def __open_connection(self): elif self._authenticator == KEY_PAIR_AUTHENTICATOR: private_key = self._private_key + private_key_passphrase = self._private_key_passphrase if self._private_key_file: private_key = _get_private_bytes_from_file( @@ -1078,6 +1080,7 @@ def __open_connection(self): self.auth_class = AuthByKeyPair( private_key=private_key, + private_key_passphrase=private_key_passphrase, timeout=self.login_timeout, backoff_generator=self._backoff_generator, ) diff --git a/test/unit/test_auth_keypair.py b/test/unit/test_auth_keypair.py index c019ca0c1..dcfedc3f7 100644 --- a/test/unit/test_auth_keypair.py +++ b/test/unit/test_auth_keypair.py @@ -65,6 +65,39 @@ def test_auth_keypair(): assert rest.master_token == "MASTER_TOKEN" +def test_auth_keypair_with_passphrase(): + """Simple Key Pair test with passphrase.""" + + passphrase = b"test" + private_key_der, public_key_der_encoded = generate_key_pair( + 2048, + passphrase=passphrase, + ) + application = "testapplication" + account = "testaccount" + user = "testuser" + auth_instance = AuthByKeyPair( + private_key=private_key_der, + private_key_passphrase=passphrase, + ) + auth_instance._retry_ctx.set_start_time() + auth_instance.handle_timeout( + authenticator="SNOWFLAKE_JWT", + service_name=None, + account=account, + user=user, + password=None, + ) + + # success test case + rest = _init_rest(application, _create_mock_auth_keypair_rest_response()) + auth = Auth(rest) + auth.authenticate(auth_instance, account, user) + assert not rest._connection.errorhandler.called # not error + assert rest.token == "TOKEN" + assert rest.master_token == "MASTER_TOKEN" + + def test_auth_keypair_abc(): """Simple Key Pair test using abstraction layer.""" private_key_der, public_key_der_encoded = generate_key_pair(2048) @@ -153,7 +186,7 @@ def _init_rest(application, post_requset): return rest -def generate_key_pair(key_length): +def generate_key_pair(key_length: int, *, passphrase: bytes | None = None): private_key = rsa.generate_private_key( backend=default_backend(), public_exponent=65537, key_size=key_length ) @@ -161,7 +194,11 @@ def generate_key_pair(key_length): private_key_der = private_key.private_bytes( encoding=serialization.Encoding.DER, format=serialization.PrivateFormat.PKCS8, - encryption_algorithm=serialization.NoEncryption(), + encryption_algorithm=( + serialization.BestAvailableEncryption(passphrase) + if passphrase + else serialization.NoEncryption() + ), ) public_key_pem = (