Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1831103/SNOW-1853435: Refresh token & OAuth tokens caching support #2009

Merged
merged 10 commits into from
Dec 20, 2024
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
/*
* Copyright (c) 2024 Snowflake Computing Inc. All rights reserved.
*/

package net.snowflake.client.core;

enum CachedCredentialType {
ID_TOKEN("ID_TOKEN"),
MFA_TOKEN("MFATOKEN"),
OAUTH_ACCESS_TOKEN("OAUTH_ACCESS_TOKEN"),
OAUTH_REFRESH_TOKEN("OAUTH_REFRESH_TOKEN");

private final String value;

CachedCredentialType(String value) {
this.value = value;
}

String getValue() {
return value;
}
}
2 changes: 2 additions & 0 deletions src/main/java/net/snowflake/client/core/Constants.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ public final class Constants {
// Error code for all invalid id token cases during login request
public static final int ID_TOKEN_INVALID_LOGIN_REQUEST_GS_CODE = 390195;

public static final int OAUTH_ACCESS_TOKEN_EXPIRED_GS_CODE = 390318;

// Error message for IOException when no space is left for GET
public static final String NO_SPACE_LEFT_ON_DEVICE_ERR = "No space left on device";

Expand Down
157 changes: 123 additions & 34 deletions src/main/java/net/snowflake/client/core/CredentialManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,6 @@ public class CredentialManager {

private SecureStorageManager secureStorageManager;

private static final String ID_TOKEN = "ID_TOKEN";
private static final String MFA_TOKEN = "MFATOKEN";

private CredentialManager() {
initSecureStorageManager();
}
Expand Down Expand Up @@ -70,7 +67,7 @@ void fillCachedIdToken(SFLoginInput loginInput) throws SFException {
"Looking for cached id token for user: {}, host: {}",
loginInput.getUserName(),
loginInput.getHostFromServerUrl());
fillCachedCredential(loginInput, ID_TOKEN);
fillCachedCredential(loginInput, CachedCredentialType.ID_TOKEN);
}

/**
Expand All @@ -83,7 +80,33 @@ void fillCachedMfaToken(SFLoginInput loginInput) throws SFException {
"Looking for cached mfa token for user: {}, host: {}",
loginInput.getUserName(),
loginInput.getHostFromServerUrl());
fillCachedCredential(loginInput, MFA_TOKEN);
fillCachedCredential(loginInput, CachedCredentialType.MFA_TOKEN);
}

/**
* Reuse the cached OAuth access token stored locally
*
* @param loginInput login input to attach access token
*/
void fillCachedOAuthAccessToken(SFLoginInput loginInput) throws SFException {
logger.debug(
"Looking for cached OAuth access token for user: {}, host: {}",
loginInput.getUserName(),
loginInput.getHostFromServerUrl());
fillCachedCredential(loginInput, CachedCredentialType.OAUTH_ACCESS_TOKEN);
}

/**
* Reuse the cached OAuth refresh token stored locally
*
* @param loginInput login input to attach refresh token
*/
void fillCachedOAuthRefreshToken(SFLoginInput loginInput) throws SFException {
logger.debug(
sfc-gh-astachowski marked this conversation as resolved.
Show resolved Hide resolved
"Looking for cached OAuth refresh token for user: {}, host: {}",
loginInput.getUserName(),
loginInput.getHostFromServerUrl());
fillCachedCredential(loginInput, CachedCredentialType.OAUTH_REFRESH_TOKEN);
}

/**
Expand All @@ -92,18 +115,18 @@ void fillCachedMfaToken(SFLoginInput loginInput) throws SFException {
* @param loginInput login input to attach token
* @param credType credential type to retrieve
*/
synchronized void fillCachedCredential(SFLoginInput loginInput, String credType)
synchronized void fillCachedCredential(SFLoginInput loginInput, CachedCredentialType credType)
throws SFException {
if (secureStorageManager == null) {
logMissingJnaJarForSecureLocalStorage();
return;
}

String cred = null;
String cred;
try {
cred =
secureStorageManager.getCredential(
loginInput.getHostFromServerUrl(), loginInput.getUserName(), credType);
loginInput.getHostFromServerUrl(), loginInput.getUserName(), credType.getValue());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if it is correct for oauth. Guessing from the code above getHostFromServerUrl returns snowflake host, right? For OAuth it should be IDP address. Otherwise we could accidentally leak tokens between IDPs.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Refactored according to your suggestion.

} catch (NoClassDefFoundError error) {
logMissingJnaJarForSecureLocalStorage();
return;
Expand All @@ -114,24 +137,43 @@ synchronized void fillCachedCredential(SFLoginInput loginInput, String credType)
}

// cred can be null
if (credType == ID_TOKEN) {
logger.debug(
"Setting {}id token for user: {}, host: {}",
cred == null ? "null " : "",
loginInput.getUserName(),
loginInput.getHostFromServerUrl());
loginInput.setIdToken(cred);
} else if (credType == MFA_TOKEN) {
logger.debug(
"Setting {}mfa token for user: {}, host: {}",
cred == null ? "null " : "",
loginInput.getUserName(),
loginInput.getHostFromServerUrl());
loginInput.setMfaToken(cred);
} else {
logger.debug("Unrecognized type {} for local cached credential", credType);
switch (credType) {
case ID_TOKEN:
logger.debug(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the messages are different only in the part after {} - could we make the login before and once?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

"Setting {}id token for user: {}, host: {}",
cred == null ? "null " : "",
loginInput.getUserName(),
loginInput.getHostFromServerUrl());
loginInput.setIdToken(cred);
break;
case MFA_TOKEN:
logger.debug(
"Setting {}mfa token for user: {}, host: {}",
cred == null ? "null " : "",
loginInput.getUserName(),
loginInput.getHostFromServerUrl());
loginInput.setMfaToken(cred);
break;
case OAUTH_ACCESS_TOKEN:
logger.debug(
"Setting {}OAuth access token for user: {}, host: {}",
cred == null ? "null " : "",
loginInput.getUserName(),
loginInput.getHostFromServerUrl());
loginInput.setOauthAccessToken(cred);
break;
case OAUTH_REFRESH_TOKEN:
logger.debug(
"Setting {}OAuth refresh token for user: {}, host: {}",
cred == null ? "null " : "",
loginInput.getUserName(),
loginInput.getHostFromServerUrl());
loginInput.setOauthRefreshToken(cred);
break;
default:
logger.debug("Unrecognized type {} for local cached credential", credType);
break;
}
return;
}

/**
Expand All @@ -145,7 +187,7 @@ void writeIdToken(SFLoginInput loginInput, SFLoginOutput loginOutput) throws SFE
"Caching id token in a secure storage for user: {}, host: {}",
loginInput.getUserName(),
loginInput.getHostFromServerUrl());
writeTemporaryCredential(loginInput, loginOutput.getIdToken(), ID_TOKEN);
writeTemporaryCredential(loginInput, loginOutput.getIdToken(), CachedCredentialType.ID_TOKEN);
}

/**
Expand All @@ -159,7 +201,35 @@ void writeMfaToken(SFLoginInput loginInput, SFLoginOutput loginOutput) throws SF
"Caching mfa token in a secure storage for user: {}, host: {}",
loginInput.getUserName(),
loginInput.getHostFromServerUrl());
writeTemporaryCredential(loginInput, loginOutput.getMfaToken(), MFA_TOKEN);
writeTemporaryCredential(loginInput, loginOutput.getMfaToken(), CachedCredentialType.MFA_TOKEN);
}

/**
* Store OAuth Access Token
*
* @param loginInput loginInput to denote to the cache
*/
void writeOAuthAccessToken(SFLoginInput loginInput) throws SFException {
logger.debug(
"Caching OAuth access token in a secure storage for user: {}, host: {}",
loginInput.getUserName(),
loginInput.getHostFromServerUrl());
writeTemporaryCredential(
loginInput, loginInput.getOauthAccessToken(), CachedCredentialType.OAUTH_ACCESS_TOKEN);
}

/**
* Store OAuth Refresh Token
*
* @param loginInput loginInput to denote to the cache
*/
void writeOAuthRefreshToken(SFLoginInput loginInput) throws SFException {
logger.debug(
sfc-gh-astachowski marked this conversation as resolved.
Show resolved Hide resolved
"Caching OAuth refresh token in a secure storage for user: {}, host: {}",
loginInput.getUserName(),
loginInput.getHostFromServerUrl());
writeTemporaryCredential(
loginInput, loginInput.getOauthRefreshToken(), CachedCredentialType.OAUTH_REFRESH_TOKEN);
}

/**
Expand All @@ -169,8 +239,8 @@ void writeMfaToken(SFLoginInput loginInput, SFLoginOutput loginOutput) throws SF
* @param cred the credential
* @param credType type of the credential
*/
synchronized void writeTemporaryCredential(SFLoginInput loginInput, String cred, String credType)
throws SFException {
synchronized void writeTemporaryCredential(
SFLoginInput loginInput, String cred, CachedCredentialType credType) throws SFException {
if (Strings.isNullOrEmpty(cred)) {
logger.debug("No {} is given.", credType);
return; // no credential
Expand All @@ -183,7 +253,7 @@ synchronized void writeTemporaryCredential(SFLoginInput loginInput, String cred,

try {
secureStorageManager.setCredential(
loginInput.getHostFromServerUrl(), loginInput.getUserName(), credType, cred);
loginInput.getHostFromServerUrl(), loginInput.getUserName(), credType.getValue(), cred);
} catch (NoClassDefFoundError error) {
logMissingJnaJarForSecureLocalStorage();
}
Expand All @@ -193,14 +263,32 @@ synchronized void writeTemporaryCredential(SFLoginInput loginInput, String cred,
void deleteIdTokenCache(String host, String user) {
logger.debug(
"Removing cached id token from a secure storage for user: {}, host: {}", user, host);
deleteTemporaryCredential(host, user, ID_TOKEN);
deleteTemporaryCredential(host, user, CachedCredentialType.ID_TOKEN);
}

/** Delete the mfa token cache */
void deleteMfaTokenCache(String host, String user) {
logger.debug(
"Removing cached mfa token from a secure storage for user: {}, host: {}", user, host);
deleteTemporaryCredential(host, user, MFA_TOKEN);
deleteTemporaryCredential(host, user, CachedCredentialType.MFA_TOKEN);
}

/** Delete the OAuth access token cache */
void deleteOAuthAccessTokenCache(String host, String user) {
logger.debug(
"Removing cached OAuth access token from a secure storage for user: {}, host: {}",
user,
host);
deleteTemporaryCredential(host, user, CachedCredentialType.OAUTH_ACCESS_TOKEN);
}

/** Delete the OAuth refresh token cache */
void deleteOAuthRefreshTokenCache(String host, String user) {
logger.debug(
"Removing cached OAuth refresh token from a secure storage for user: {}, host: {}",
user,
host);
deleteTemporaryCredential(host, user, CachedCredentialType.OAUTH_REFRESH_TOKEN);
}

/**
Expand All @@ -210,14 +298,15 @@ void deleteMfaTokenCache(String host, String user) {
* @param user user name
* @param credType type of the credential
*/
synchronized void deleteTemporaryCredential(String host, String user, String credType) {
synchronized void deleteTemporaryCredential(
String host, String user, CachedCredentialType credType) {
if (secureStorageManager == null) {
logMissingJnaJarForSecureLocalStorage();
return;
}

try {
secureStorageManager.deleteCredential(host, user, credType);
secureStorageManager.deleteCredential(host, user, credType.getValue());
} catch (NoClassDefFoundError error) {
logMissingJnaJarForSecureLocalStorage();
}
Expand Down
36 changes: 33 additions & 3 deletions src/main/java/net/snowflake/client/core/SFLoginInput.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ public class SFLoginInput {
private String warehouse;
private String role;
private boolean validateDefaultParameters;
private String originAuthenticator;
private String authenticator;
private String oktaUserName;
private String accountName;
Expand All @@ -41,6 +42,8 @@ public class SFLoginInput {
private String application;
private String idToken;
private String mfaToken;
private String oauthAccessToken;
private String oauthRefreshToken;
private String serviceName;
private OCSPMode ocspMode;
private HttpClientSettingsKey httpClientKey;
Expand Down Expand Up @@ -317,6 +320,25 @@ SFLoginInput setMfaToken(String mfaToken) {
return this;
}

String getOauthAccessToken() {
return oauthAccessToken;
}

SFLoginInput setOauthAccessToken(String oauthAccessToken) {
this.oauthAccessToken = oauthAccessToken;
return this;
}

@SnowflakeJdbcInternalApi
public String getOauthRefreshToken() {
return oauthRefreshToken;
}

SFLoginInput setOauthRefreshToken(String oauthRefreshToken) {
this.oauthRefreshToken = oauthRefreshToken;
return this;
}

Map<String, Object> getSessionParameters() {
return sessionParameters;
}
Expand Down Expand Up @@ -404,13 +426,13 @@ SFLoginInput setHttpClientSettingsKey(HttpClientSettingsKey key) {
this.httpClientKey = key;
return this;
}

// Opaque string sent for Snowsight account activation

String getInFlightCtx() {
return inFlightCtx;
}

// Opaque string sent for Snowsight account activation

SFLoginInput setInFlightCtx(String inFlightCtx) {
this.inFlightCtx = inFlightCtx;
return this;
Expand All @@ -428,7 +450,6 @@ SFLoginInput setDisableSamlURLCheck(boolean disableSamlURLCheck) {
Map<String, String> getAdditionalHttpHeadersForSnowsight() {
return additionalHttpHeadersForSnowsight;
}

/**
* Set additional http headers to apply to the outgoing request. The additional headers cannot be
* used to replace or overwrite a header in use by the driver. These will be applied to the
Expand Down Expand Up @@ -504,4 +525,13 @@ public SFLoginInput setOauthLoginInput(SFOauthLoginInput oauthLoginInput) {
this.oauthLoginInput = oauthLoginInput;
return this;
}

String getOriginAuthenticator() {
return originAuthenticator;
}

SFLoginInput setOriginAuthenticator(String originAuthenticator) {
this.originAuthenticator = originAuthenticator;
return this;
}
}
Loading
Loading