Skip to content

Commit

Permalink
Fixed oauth cache + added state check
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-dheyman committed Dec 20, 2024
1 parent 284ed4f commit c99757b
Show file tree
Hide file tree
Showing 18 changed files with 439 additions and 147 deletions.
2 changes: 2 additions & 0 deletions src/main/java/net/snowflake/client/core/Constants.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ public final class Constants {

public static final int OAUTH_ACCESS_TOKEN_EXPIRED_GS_CODE = 390318;

public static final int OAUTH_ACCESS_TOKEN_INVALID_GS_CODE = 390303;

// Error message for IOException when no space is left for GET
public static final String NO_SPACE_LEFT_ON_DEVICE_ERR = "No space left on device";

Expand Down
141 changes: 82 additions & 59 deletions src/main/java/net/snowflake/client/core/CredentialManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package net.snowflake.client.core;

import com.google.common.base.Strings;
import java.net.URI;
import net.snowflake.client.log.SFLogger;
import net.snowflake.client.log.SFLoggerFactory;

Expand Down Expand Up @@ -34,19 +35,19 @@ private void initSecureStorageManager() {
}

/** Helper function for tests to go back to normal settings. */
void resetSecureStorageManager() {
static void resetSecureStorageManager() {
logger.debug("Resetting the secure storage manager");
initSecureStorageManager();
getInstance().initSecureStorageManager();
}

/**
* Testing purpose. Inject a mock manager.
*
* @param manager SecureStorageManager
*/
void injectSecureStorageManager(SecureStorageManager manager) {
static void injectSecureStorageManager(SecureStorageManager manager) {
logger.debug("Injecting secure storage manager");
secureStorageManager = manager;
getInstance().secureStorageManager = manager;
}

private static class CredentialManagerHolder {
Expand All @@ -67,7 +68,12 @@ static void fillCachedIdToken(SFLoginInput loginInput) throws SFException {
"Looking for cached id token for user: {}, host: {}",
loginInput.getUserName(),
loginInput.getHostFromServerUrl());
getInstance().fillCachedCredential(loginInput, CachedCredentialType.ID_TOKEN);
getInstance()
.fillCachedCredential(
loginInput,
loginInput.getHostFromServerUrl(),
loginInput.getUserName(),
CachedCredentialType.ID_TOKEN);
}

/**
Expand All @@ -80,7 +86,12 @@ static void fillCachedMfaToken(SFLoginInput loginInput) throws SFException {
"Looking for cached mfa token for user: {}, host: {}",
loginInput.getUserName(),
loginInput.getHostFromServerUrl());
getInstance().fillCachedCredential(loginInput, CachedCredentialType.MFA_TOKEN);
getInstance()
.fillCachedCredential(
loginInput,
loginInput.getHostFromServerUrl(),
loginInput.getUserName(),
CachedCredentialType.MFA_TOKEN);
}

/**
Expand All @@ -89,11 +100,14 @@ static void fillCachedMfaToken(SFLoginInput loginInput) throws SFException {
* @param loginInput login input to attach access token
*/
static void fillCachedOAuthAccessToken(SFLoginInput loginInput) throws SFException {
String host = getHostForOAuthCacheKey(loginInput);
logger.debug(
"Looking for cached OAuth access token for user: {}, host: {}",
loginInput.getUserName(),
loginInput.getHostFromServerUrl());
getInstance().fillCachedCredential(loginInput, CachedCredentialType.OAUTH_ACCESS_TOKEN);
host);
getInstance()
.fillCachedCredential(
loginInput, host, loginInput.getUserName(), CachedCredentialType.OAUTH_ACCESS_TOKEN);
}

/**
Expand All @@ -102,31 +116,27 @@ static void fillCachedOAuthAccessToken(SFLoginInput loginInput) throws SFExcepti
* @param loginInput login input to attach refresh token
*/
static void fillCachedOAuthRefreshToken(SFLoginInput loginInput) throws SFException {
String host = getHostForOAuthCacheKey(loginInput);
logger.debug(
"Looking for cached OAuth refresh token for user: {}, host: {}",
loginInput.getUserName(),
loginInput.getHostFromServerUrl());
getInstance().fillCachedCredential(loginInput, CachedCredentialType.OAUTH_REFRESH_TOKEN);
host);
getInstance()
.fillCachedCredential(
loginInput, host, loginInput.getUserName(), CachedCredentialType.OAUTH_REFRESH_TOKEN);
}

/**
* Reuse the cached token stored locally
*
* @param loginInput login input to attach token
* @param credType credential type to retrieve
*/
synchronized void fillCachedCredential(SFLoginInput loginInput, CachedCredentialType credType)
throws SFException {
/** Reuse the cached token stored locally */
synchronized void fillCachedCredential(
SFLoginInput loginInput, String host, String username, CachedCredentialType credType) {
if (secureStorageManager == null) {
logMissingJnaJarForSecureLocalStorage();
return;
}

String cred;
try {
cred =
secureStorageManager.getCredential(
loginInput.getHostFromServerUrl(), loginInput.getUserName(), credType.getValue());
cred = secureStorageManager.getCredential(host, username, credType.getValue());
} catch (NoClassDefFoundError error) {
logMissingJnaJarForSecureLocalStorage();
return;
Expand All @@ -140,8 +150,8 @@ synchronized void fillCachedCredential(SFLoginInput loginInput, CachedCredential
"Setting {}{} token for user: {}, host: {}",
cred == null ? "null " : "",
credType.getValue(),
loginInput.getUserName(),
loginInput.getHostFromServerUrl());
username,
host);
switch (credType) {
case ID_TOKEN:
loginInput.setIdToken(cred);
Expand All @@ -161,36 +171,30 @@ synchronized void fillCachedCredential(SFLoginInput loginInput, CachedCredential
}
}

/**
* Store ID Token
*
* @param loginInput loginInput to denote to the cache
* @param loginOutput loginOutput to denote to the cache
*/
static void writeIdToken(SFLoginInput loginInput, SFLoginOutput loginOutput) throws SFException {
static void writeIdToken(SFLoginInput loginInput, String idToken) throws SFException {
logger.debug(
"Caching id token in a secure storage for user: {}, host: {}",
loginInput.getUserName(),
loginInput.getHostFromServerUrl());
getInstance()
.writeTemporaryCredential(
loginInput, loginOutput.getIdToken(), CachedCredentialType.ID_TOKEN);
loginInput.getHostFromServerUrl(),
loginInput.getUserName(),
idToken,
CachedCredentialType.ID_TOKEN);
}

/**
* Store MFA Token
*
* @param loginInput loginInput to denote to the cache
* @param loginOutput loginOutput to denote to the cache
*/
static void writeMfaToken(SFLoginInput loginInput, SFLoginOutput loginOutput) throws SFException {
static void writeMfaToken(SFLoginInput loginInput, String mfaToken) throws SFException {
logger.debug(
"Caching mfa token in a secure storage for user: {}, host: {}",
loginInput.getUserName(),
loginInput.getHostFromServerUrl());
getInstance()
.writeTemporaryCredential(
loginInput, loginOutput.getMfaToken(), CachedCredentialType.MFA_TOKEN);
loginInput.getHostFromServerUrl(),
loginInput.getUserName(),
mfaToken,
CachedCredentialType.MFA_TOKEN);
}

/**
Expand All @@ -199,13 +203,17 @@ static void writeMfaToken(SFLoginInput loginInput, SFLoginOutput loginOutput) th
* @param loginInput loginInput to denote to the cache
*/
static void writeOAuthAccessToken(SFLoginInput loginInput) throws SFException {
String host = getHostForOAuthCacheKey(loginInput);
logger.debug(
"Caching OAuth access token in a secure storage for user: {}, host: {}",
loginInput.getUserName(),
loginInput.getHostFromServerUrl());
host);
getInstance()
.writeTemporaryCredential(
loginInput, loginInput.getOauthAccessToken(), CachedCredentialType.OAUTH_ACCESS_TOKEN);
host,
loginInput.getUserName(),
loginInput.getOauthAccessToken(),
CachedCredentialType.OAUTH_ACCESS_TOKEN);
}

/**
Expand All @@ -214,26 +222,22 @@ static void writeOAuthAccessToken(SFLoginInput loginInput) throws SFException {
* @param loginInput loginInput to denote to the cache
*/
static void writeOAuthRefreshToken(SFLoginInput loginInput) throws SFException {
String host = getHostForOAuthCacheKey(loginInput);
logger.debug(
"Caching OAuth refresh token in a secure storage for user: {}, host: {}",
loginInput.getUserName(),
loginInput.getHostFromServerUrl());
host);
getInstance()
.writeTemporaryCredential(
loginInput,
host,
loginInput.getUserName(),
loginInput.getOauthRefreshToken(),
CachedCredentialType.OAUTH_REFRESH_TOKEN);
}

/**
* Store the temporary credential
*
* @param loginInput loginInput to denote to the cache
* @param cred the credential
* @param credType type of the credential
*/
/** Store the temporary credential */
synchronized void writeTemporaryCredential(
SFLoginInput loginInput, String cred, CachedCredentialType credType) throws SFException {
String host, String user, String cred, CachedCredentialType credType) {
if (Strings.isNullOrEmpty(cred)) {
logger.debug("No {} is given.", credType);
return; // no credential
Expand All @@ -245,8 +249,7 @@ synchronized void writeTemporaryCredential(
}

try {
secureStorageManager.setCredential(
loginInput.getHostFromServerUrl(), loginInput.getUserName(), credType.getValue(), cred);
secureStorageManager.setCredential(host, user, credType.getValue(), cred);
} catch (NoClassDefFoundError error) {
logMissingJnaJarForSecureLocalStorage();
}
Expand All @@ -267,21 +270,41 @@ static void deleteMfaTokenCache(String host, String user) {
}

/** Delete the OAuth access token cache */
static void deleteOAuthAccessTokenCache(String host, String user) {
static void deleteOAuthAccessTokenCache(SFLoginInput loginInput) throws SFException {
String host = getHostForOAuthCacheKey(loginInput);
logger.debug(
"Removing cached OAuth access token from a secure storage for user: {}, host: {}",
user,
loginInput.getUserName(),
host);
getInstance().deleteTemporaryCredential(host, user, CachedCredentialType.OAUTH_ACCESS_TOKEN);
getInstance()
.deleteTemporaryCredential(
host, loginInput.getUserName(), CachedCredentialType.OAUTH_ACCESS_TOKEN);
}

/** Delete the OAuth refresh token cache */
static void deleteOAuthRefreshTokenCache(String host, String user) {
static void deleteOAuthRefreshTokenCache(SFLoginInput loginInput) throws SFException {
String host = getHostForOAuthCacheKey(loginInput);
logger.debug(
"Removing cached OAuth refresh token from a secure storage for user: {}, host: {}",
user,
loginInput.getUserName(),
host);
getInstance().deleteTemporaryCredential(host, user, CachedCredentialType.OAUTH_REFRESH_TOKEN);
getInstance()
.deleteTemporaryCredential(
host, loginInput.getUserName(), CachedCredentialType.OAUTH_REFRESH_TOKEN);
}

/**
* Method required for OAuth token caching, since actual token is not Snowflake account-specific,
* but rather IdP-specific
*/
static String getHostForOAuthCacheKey(SFLoginInput loginInput) throws SFException {
String externalTokenRequestUrl = loginInput.getOauthLoginInput().getExternalTokenRequestUrl();
if (externalTokenRequestUrl != null) {
URI parsedUrl = URI.create(externalTokenRequestUrl);
return parsedUrl.getHost();
} else {
return loginInput.getHostFromServerUrl();
}
}

/**
Expand Down
37 changes: 18 additions & 19 deletions src/main/java/net/snowflake/client/core/SessionUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ static SFLoginOutput openSession(
readCachedTokens(loginInput);

if (OAuthAccessTokenProviderFactory.isEligible(getAuthenticator(loginInput))) {
updateInputWithOAuthAccessToken(loginInput);
obtainAuthAccessTokenAndUpdateInput(loginInput);
}

try {
Expand All @@ -339,23 +339,24 @@ static SFLoginOutput openSession(
refreshOAuthAccessTokenAndUpdateInput(loginInput);
} else {
loginInput.setAuthenticator(loginInput.getOriginAuthenticator());
obtainOAuthAccessTokenAndUpdateInput(loginInput);
fetchOAuthAccessTokenAndUpdateInput(loginInput);
}
}
return newSession(loginInput, connectionPropertiesMap, tracingLevel);
}
}

private static void updateInputWithOAuthAccessToken(SFLoginInput loginInput) throws SFException {
if (loginInput.getOauthAccessToken() != null) { // Access Token cached
private static void obtainAuthAccessTokenAndUpdateInput(SFLoginInput loginInput)
throws SFException {
if (loginInput.getOauthAccessToken() != null) { // Access Token was cached
loginInput.setAuthenticator(AuthenticatorType.OAUTH.name());
loginInput.setToken(loginInput.getOauthAccessToken());
} else { // Access Token not cached
obtainOAuthAccessTokenAndUpdateInput(loginInput);
fetchOAuthAccessTokenAndUpdateInput(loginInput);
}
}

private static void obtainOAuthAccessTokenAndUpdateInput(SFLoginInput loginInput)
private static void fetchOAuthAccessTokenAndUpdateInput(SFLoginInput loginInput)
throws SFException {
OAuthAccessTokenProviderFactory accessTokenProviderFactory =
new OAuthAccessTokenProviderFactory(
Expand Down Expand Up @@ -387,9 +388,9 @@ private static void refreshOAuthAccessTokenAndUpdateInput(SFLoginInput loginInpu
logger.debug(
"Refreshing OAuth access token failed. Removing OAuth refresh token from cache and restarting OAuth flow...",
e);
deleteOAuthRefreshTokenCache(loginInput.getHostFromServerUrl(), loginInput.getUserName());
CredentialManager.deleteOAuthRefreshTokenCache(loginInput);
loginInput.setAuthenticator(loginInput.getOriginAuthenticator());
obtainOAuthAccessTokenAndUpdateInput(loginInput);
fetchOAuthAccessTokenAndUpdateInput(loginInput);
}
}

Expand Down Expand Up @@ -860,9 +861,15 @@ static SFLoginOutput newSession(
SnowflakeUtil.checkErrorAndThrowExceptionIncludingReauth(jsonNode);
}

if (errorCode == Constants.OAUTH_ACCESS_TOKEN_INVALID_GS_CODE) {
logger.debug("OAuth Access Token Invalid: {}", errorCode);
loginInput.setOauthAccessToken(null);
CredentialManager.deleteOAuthAccessTokenCache(loginInput);
}

if (errorCode == Constants.OAUTH_ACCESS_TOKEN_EXPIRED_GS_CODE) {
loginInput.setOauthAccessToken(null);
deleteOAuthAccessTokenCache(loginInput.getHostFromServerUrl(), loginInput.getUserName());
CredentialManager.deleteOAuthAccessTokenCache(loginInput);

logger.debug("OAuth Access Token Expired: {}", errorCode);
SnowflakeUtil.checkErrorAndThrowExceptionIncludingReauth(jsonNode);
Expand Down Expand Up @@ -1007,7 +1014,7 @@ static SFLoginOutput newSession(

if (asBoolean(loginInput.getSessionParameters().get(CLIENT_STORE_TEMPORARY_CREDENTIAL))) {
if (consentCacheIdToken) {
CredentialManager.writeIdToken(loginInput, ret);
CredentialManager.writeIdToken(loginInput, ret.getIdToken());
}
if (loginInput.getOauthAccessToken() != null) {
CredentialManager.writeOAuthAccessToken(loginInput);
Expand All @@ -1018,7 +1025,7 @@ static SFLoginOutput newSession(
}

if (asBoolean(loginInput.getSessionParameters().get(CLIENT_REQUEST_MFA_TOKEN))) {
CredentialManager.writeMfaToken(loginInput, ret);
CredentialManager.writeMfaToken(loginInput, ret.getMfaToken());
}

stopwatch.stop();
Expand Down Expand Up @@ -1065,14 +1072,6 @@ public static void deleteMfaTokenCache(String host, String user) {
CredentialManager.deleteMfaTokenCache(host, user);
}

private static void deleteOAuthAccessTokenCache(String host, String user) {
CredentialManager.deleteOAuthAccessTokenCache(host, user);
}

private static void deleteOAuthRefreshTokenCache(String host, String user) {
CredentialManager.deleteOAuthRefreshTokenCache(host, user);
}

/**
* Renew a session.
*
Expand Down
Loading

0 comments on commit c99757b

Please sign in to comment.