Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add private_key_passphrase keyword. #2131

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/snowflake/connector/auth/keypair.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class AuthByKeyPair(AuthByPlugin):
def __init__(
self,
private_key: bytes | RSAPrivateKey,
private_key_passphrase: bytes | None = None,
lifetime_in_seconds: int = LIFETIME,
**kwargs,
) -> None:
Expand Down Expand Up @@ -76,6 +77,7 @@ def __init__(
)

self._private_key: bytes | RSAPrivateKey | None = private_key
self._private_key_passphrase: bytes | None = private_key_passphrase
self._jwt_token = ""
self._jwt_token_exp = 0
self._lifetime = timedelta(
Expand Down Expand Up @@ -109,7 +111,7 @@ def prepare(
try:
private_key = load_der_private_key(
data=self._private_key,
password=None,
password=self._private_key_passphrase,
backend=default_backend(),
)
except Exception as e:
Expand Down
3 changes: 3 additions & 0 deletions src/snowflake/connector/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ def _get_private_bytes_from_file(
"passcode_in_password": (False, bool), # Snowflake MFA
"passcode": (None, (type(None), str)), # Snowflake MFA
"private_key": (None, (type(None), bytes, RSAPrivateKey)),
"private_key_passphrase": (None, (type(None), bytes)),
"private_key_file": (None, (type(None), str)),
"private_key_file_pwd": (None, (type(None), str, bytes)),
"token": (None, (type(None), str)), # OAuth/JWT/PAT Token
Expand Down Expand Up @@ -1069,6 +1070,7 @@ def __open_connection(self):

elif self._authenticator == KEY_PAIR_AUTHENTICATOR:
private_key = self._private_key
private_key_passphrase = self._private_key_passphrase

if self._private_key_file:
private_key = _get_private_bytes_from_file(
Expand All @@ -1078,6 +1080,7 @@ def __open_connection(self):

self.auth_class = AuthByKeyPair(
private_key=private_key,
private_key_passphrase=private_key_passphrase,
timeout=self.login_timeout,
backoff_generator=self._backoff_generator,
)
Expand Down
41 changes: 39 additions & 2 deletions test/unit/test_auth_keypair.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,39 @@ def test_auth_keypair():
assert rest.master_token == "MASTER_TOKEN"


def test_auth_keypair_with_passphrase():
"""Simple Key Pair test with passphrase."""

passphrase = b"test"
private_key_der, public_key_der_encoded = generate_key_pair(
2048,
passphrase=passphrase,
)
application = "testapplication"
account = "testaccount"
user = "testuser"
auth_instance = AuthByKeyPair(
private_key=private_key_der,
private_key_passphrase=passphrase,
)
auth_instance._retry_ctx.set_start_time()
auth_instance.handle_timeout(
authenticator="SNOWFLAKE_JWT",
service_name=None,
account=account,
user=user,
password=None,
)

# success test case
rest = _init_rest(application, _create_mock_auth_keypair_rest_response())
auth = Auth(rest)
auth.authenticate(auth_instance, account, user)
assert not rest._connection.errorhandler.called # not error
assert rest.token == "TOKEN"
assert rest.master_token == "MASTER_TOKEN"


def test_auth_keypair_abc():
"""Simple Key Pair test using abstraction layer."""
private_key_der, public_key_der_encoded = generate_key_pair(2048)
Expand Down Expand Up @@ -153,15 +186,19 @@ def _init_rest(application, post_requset):
return rest


def generate_key_pair(key_length):
def generate_key_pair(key_length: int, *, passphrase: bytes | None = None):
private_key = rsa.generate_private_key(
backend=default_backend(), public_exponent=65537, key_size=key_length
)

private_key_der = private_key.private_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
encryption_algorithm=(
serialization.BestAvailableEncryption(passphrase)
if passphrase
else serialization.NoEncryption()
),
)

public_key_pem = (
Expand Down
Loading