From c99757b95289364f8a08034fe8a0c858cd16cebb Mon Sep 17 00:00:00 2001 From: Dawid Heyman Date: Fri, 20 Dec 2024 15:35:05 +0100 Subject: [PATCH] Fixed oauth cache + added state check --- .../net/snowflake/client/core/Constants.java | 2 + .../client/core/CredentialManager.java | 141 +++++++------ .../snowflake/client/core/SessionUtil.java | 37 ++-- ...uthAccessTokenForRefreshTokenProvider.java | 2 +- .../OAuthAccessTokenProviderFactory.java | 2 +- ...hAuthorizationCodeAccessTokenProvider.java | 69 ++++-- ...hClientCredentialsAccessTokenProvider.java | 2 +- .../core/auth/oauth/RandomStateProvider.java | 16 ++ .../client/core/auth/oauth/StateProvider.java | 12 ++ .../snowflake/client/AbstractDriverIT.java | 2 - .../client/core/CredentialManagerTest.java | 197 ++++++++++++++++++ .../OAuthAuthorizationCodeFlowLatestIT.java | 39 +++- .../client/core/OAuthTokenCacheLatestIT.java | 38 +--- .../client/core/SnowflakeMFACacheTest.java | 4 +- .../external_idp_custom_urls.json | 2 +- .../invalid_state_error.json | 17 ++ .../authorization_code/successful_flow.json | 2 +- .../token_request_error.json | 2 +- 18 files changed, 439 insertions(+), 147 deletions(-) create mode 100644 src/main/java/net/snowflake/client/core/auth/oauth/RandomStateProvider.java create mode 100644 src/main/java/net/snowflake/client/core/auth/oauth/StateProvider.java create mode 100644 src/test/java/net/snowflake/client/core/CredentialManagerTest.java create mode 100644 src/test/resources/wiremock/mappings/oauth/authorization_code/invalid_state_error.json diff --git a/src/main/java/net/snowflake/client/core/Constants.java b/src/main/java/net/snowflake/client/core/Constants.java index 7678940d5..2d9591b96 100644 --- a/src/main/java/net/snowflake/client/core/Constants.java +++ b/src/main/java/net/snowflake/client/core/Constants.java @@ -24,6 +24,8 @@ public final class Constants { public static final int OAUTH_ACCESS_TOKEN_EXPIRED_GS_CODE = 390318; + public static final int OAUTH_ACCESS_TOKEN_INVALID_GS_CODE = 390303; + // Error message for IOException when no space is left for GET public static final String NO_SPACE_LEFT_ON_DEVICE_ERR = "No space left on device"; diff --git a/src/main/java/net/snowflake/client/core/CredentialManager.java b/src/main/java/net/snowflake/client/core/CredentialManager.java index 9bf58944b..37e3e7d48 100644 --- a/src/main/java/net/snowflake/client/core/CredentialManager.java +++ b/src/main/java/net/snowflake/client/core/CredentialManager.java @@ -5,6 +5,7 @@ package net.snowflake.client.core; import com.google.common.base.Strings; +import java.net.URI; import net.snowflake.client.log.SFLogger; import net.snowflake.client.log.SFLoggerFactory; @@ -34,9 +35,9 @@ private void initSecureStorageManager() { } /** Helper function for tests to go back to normal settings. */ - void resetSecureStorageManager() { + static void resetSecureStorageManager() { logger.debug("Resetting the secure storage manager"); - initSecureStorageManager(); + getInstance().initSecureStorageManager(); } /** @@ -44,9 +45,9 @@ void resetSecureStorageManager() { * * @param manager SecureStorageManager */ - void injectSecureStorageManager(SecureStorageManager manager) { + static void injectSecureStorageManager(SecureStorageManager manager) { logger.debug("Injecting secure storage manager"); - secureStorageManager = manager; + getInstance().secureStorageManager = manager; } private static class CredentialManagerHolder { @@ -67,7 +68,12 @@ static void fillCachedIdToken(SFLoginInput loginInput) throws SFException { "Looking for cached id token for user: {}, host: {}", loginInput.getUserName(), loginInput.getHostFromServerUrl()); - getInstance().fillCachedCredential(loginInput, CachedCredentialType.ID_TOKEN); + getInstance() + .fillCachedCredential( + loginInput, + loginInput.getHostFromServerUrl(), + loginInput.getUserName(), + CachedCredentialType.ID_TOKEN); } /** @@ -80,7 +86,12 @@ static void fillCachedMfaToken(SFLoginInput loginInput) throws SFException { "Looking for cached mfa token for user: {}, host: {}", loginInput.getUserName(), loginInput.getHostFromServerUrl()); - getInstance().fillCachedCredential(loginInput, CachedCredentialType.MFA_TOKEN); + getInstance() + .fillCachedCredential( + loginInput, + loginInput.getHostFromServerUrl(), + loginInput.getUserName(), + CachedCredentialType.MFA_TOKEN); } /** @@ -89,11 +100,14 @@ static void fillCachedMfaToken(SFLoginInput loginInput) throws SFException { * @param loginInput login input to attach access token */ static void fillCachedOAuthAccessToken(SFLoginInput loginInput) throws SFException { + String host = getHostForOAuthCacheKey(loginInput); logger.debug( "Looking for cached OAuth access token for user: {}, host: {}", loginInput.getUserName(), - loginInput.getHostFromServerUrl()); - getInstance().fillCachedCredential(loginInput, CachedCredentialType.OAUTH_ACCESS_TOKEN); + host); + getInstance() + .fillCachedCredential( + loginInput, host, loginInput.getUserName(), CachedCredentialType.OAUTH_ACCESS_TOKEN); } /** @@ -102,21 +116,19 @@ static void fillCachedOAuthAccessToken(SFLoginInput loginInput) throws SFExcepti * @param loginInput login input to attach refresh token */ static void fillCachedOAuthRefreshToken(SFLoginInput loginInput) throws SFException { + String host = getHostForOAuthCacheKey(loginInput); logger.debug( "Looking for cached OAuth refresh token for user: {}, host: {}", loginInput.getUserName(), - loginInput.getHostFromServerUrl()); - getInstance().fillCachedCredential(loginInput, CachedCredentialType.OAUTH_REFRESH_TOKEN); + host); + getInstance() + .fillCachedCredential( + loginInput, host, loginInput.getUserName(), CachedCredentialType.OAUTH_REFRESH_TOKEN); } - /** - * Reuse the cached token stored locally - * - * @param loginInput login input to attach token - * @param credType credential type to retrieve - */ - synchronized void fillCachedCredential(SFLoginInput loginInput, CachedCredentialType credType) - throws SFException { + /** Reuse the cached token stored locally */ + synchronized void fillCachedCredential( + SFLoginInput loginInput, String host, String username, CachedCredentialType credType) { if (secureStorageManager == null) { logMissingJnaJarForSecureLocalStorage(); return; @@ -124,9 +136,7 @@ synchronized void fillCachedCredential(SFLoginInput loginInput, CachedCredential String cred; try { - cred = - secureStorageManager.getCredential( - loginInput.getHostFromServerUrl(), loginInput.getUserName(), credType.getValue()); + cred = secureStorageManager.getCredential(host, username, credType.getValue()); } catch (NoClassDefFoundError error) { logMissingJnaJarForSecureLocalStorage(); return; @@ -140,8 +150,8 @@ synchronized void fillCachedCredential(SFLoginInput loginInput, CachedCredential "Setting {}{} token for user: {}, host: {}", cred == null ? "null " : "", credType.getValue(), - loginInput.getUserName(), - loginInput.getHostFromServerUrl()); + username, + host); switch (credType) { case ID_TOKEN: loginInput.setIdToken(cred); @@ -161,36 +171,30 @@ synchronized void fillCachedCredential(SFLoginInput loginInput, CachedCredential } } - /** - * Store ID Token - * - * @param loginInput loginInput to denote to the cache - * @param loginOutput loginOutput to denote to the cache - */ - static void writeIdToken(SFLoginInput loginInput, SFLoginOutput loginOutput) throws SFException { + static void writeIdToken(SFLoginInput loginInput, String idToken) throws SFException { logger.debug( "Caching id token in a secure storage for user: {}, host: {}", loginInput.getUserName(), loginInput.getHostFromServerUrl()); getInstance() .writeTemporaryCredential( - loginInput, loginOutput.getIdToken(), CachedCredentialType.ID_TOKEN); + loginInput.getHostFromServerUrl(), + loginInput.getUserName(), + idToken, + CachedCredentialType.ID_TOKEN); } - /** - * Store MFA Token - * - * @param loginInput loginInput to denote to the cache - * @param loginOutput loginOutput to denote to the cache - */ - static void writeMfaToken(SFLoginInput loginInput, SFLoginOutput loginOutput) throws SFException { + static void writeMfaToken(SFLoginInput loginInput, String mfaToken) throws SFException { logger.debug( "Caching mfa token in a secure storage for user: {}, host: {}", loginInput.getUserName(), loginInput.getHostFromServerUrl()); getInstance() .writeTemporaryCredential( - loginInput, loginOutput.getMfaToken(), CachedCredentialType.MFA_TOKEN); + loginInput.getHostFromServerUrl(), + loginInput.getUserName(), + mfaToken, + CachedCredentialType.MFA_TOKEN); } /** @@ -199,13 +203,17 @@ static void writeMfaToken(SFLoginInput loginInput, SFLoginOutput loginOutput) th * @param loginInput loginInput to denote to the cache */ static void writeOAuthAccessToken(SFLoginInput loginInput) throws SFException { + String host = getHostForOAuthCacheKey(loginInput); logger.debug( "Caching OAuth access token in a secure storage for user: {}, host: {}", loginInput.getUserName(), - loginInput.getHostFromServerUrl()); + host); getInstance() .writeTemporaryCredential( - loginInput, loginInput.getOauthAccessToken(), CachedCredentialType.OAUTH_ACCESS_TOKEN); + host, + loginInput.getUserName(), + loginInput.getOauthAccessToken(), + CachedCredentialType.OAUTH_ACCESS_TOKEN); } /** @@ -214,26 +222,22 @@ static void writeOAuthAccessToken(SFLoginInput loginInput) throws SFException { * @param loginInput loginInput to denote to the cache */ static void writeOAuthRefreshToken(SFLoginInput loginInput) throws SFException { + String host = getHostForOAuthCacheKey(loginInput); logger.debug( "Caching OAuth refresh token in a secure storage for user: {}, host: {}", loginInput.getUserName(), - loginInput.getHostFromServerUrl()); + host); getInstance() .writeTemporaryCredential( - loginInput, + host, + loginInput.getUserName(), loginInput.getOauthRefreshToken(), CachedCredentialType.OAUTH_REFRESH_TOKEN); } - /** - * Store the temporary credential - * - * @param loginInput loginInput to denote to the cache - * @param cred the credential - * @param credType type of the credential - */ + /** Store the temporary credential */ synchronized void writeTemporaryCredential( - SFLoginInput loginInput, String cred, CachedCredentialType credType) throws SFException { + String host, String user, String cred, CachedCredentialType credType) { if (Strings.isNullOrEmpty(cred)) { logger.debug("No {} is given.", credType); return; // no credential @@ -245,8 +249,7 @@ synchronized void writeTemporaryCredential( } try { - secureStorageManager.setCredential( - loginInput.getHostFromServerUrl(), loginInput.getUserName(), credType.getValue(), cred); + secureStorageManager.setCredential(host, user, credType.getValue(), cred); } catch (NoClassDefFoundError error) { logMissingJnaJarForSecureLocalStorage(); } @@ -267,21 +270,41 @@ static void deleteMfaTokenCache(String host, String user) { } /** Delete the OAuth access token cache */ - static void deleteOAuthAccessTokenCache(String host, String user) { + static void deleteOAuthAccessTokenCache(SFLoginInput loginInput) throws SFException { + String host = getHostForOAuthCacheKey(loginInput); logger.debug( "Removing cached OAuth access token from a secure storage for user: {}, host: {}", - user, + loginInput.getUserName(), host); - getInstance().deleteTemporaryCredential(host, user, CachedCredentialType.OAUTH_ACCESS_TOKEN); + getInstance() + .deleteTemporaryCredential( + host, loginInput.getUserName(), CachedCredentialType.OAUTH_ACCESS_TOKEN); } /** Delete the OAuth refresh token cache */ - static void deleteOAuthRefreshTokenCache(String host, String user) { + static void deleteOAuthRefreshTokenCache(SFLoginInput loginInput) throws SFException { + String host = getHostForOAuthCacheKey(loginInput); logger.debug( "Removing cached OAuth refresh token from a secure storage for user: {}, host: {}", - user, + loginInput.getUserName(), host); - getInstance().deleteTemporaryCredential(host, user, CachedCredentialType.OAUTH_REFRESH_TOKEN); + getInstance() + .deleteTemporaryCredential( + host, loginInput.getUserName(), CachedCredentialType.OAUTH_REFRESH_TOKEN); + } + + /** + * Method required for OAuth token caching, since actual token is not Snowflake account-specific, + * but rather IdP-specific + */ + static String getHostForOAuthCacheKey(SFLoginInput loginInput) throws SFException { + String externalTokenRequestUrl = loginInput.getOauthLoginInput().getExternalTokenRequestUrl(); + if (externalTokenRequestUrl != null) { + URI parsedUrl = URI.create(externalTokenRequestUrl); + return parsedUrl.getHost(); + } else { + return loginInput.getHostFromServerUrl(); + } } /** diff --git a/src/main/java/net/snowflake/client/core/SessionUtil.java b/src/main/java/net/snowflake/client/core/SessionUtil.java index e7ec1d508..a00133b9b 100644 --- a/src/main/java/net/snowflake/client/core/SessionUtil.java +++ b/src/main/java/net/snowflake/client/core/SessionUtil.java @@ -328,7 +328,7 @@ static SFLoginOutput openSession( readCachedTokens(loginInput); if (OAuthAccessTokenProviderFactory.isEligible(getAuthenticator(loginInput))) { - updateInputWithOAuthAccessToken(loginInput); + obtainAuthAccessTokenAndUpdateInput(loginInput); } try { @@ -339,23 +339,24 @@ static SFLoginOutput openSession( refreshOAuthAccessTokenAndUpdateInput(loginInput); } else { loginInput.setAuthenticator(loginInput.getOriginAuthenticator()); - obtainOAuthAccessTokenAndUpdateInput(loginInput); + fetchOAuthAccessTokenAndUpdateInput(loginInput); } } return newSession(loginInput, connectionPropertiesMap, tracingLevel); } } - private static void updateInputWithOAuthAccessToken(SFLoginInput loginInput) throws SFException { - if (loginInput.getOauthAccessToken() != null) { // Access Token cached + private static void obtainAuthAccessTokenAndUpdateInput(SFLoginInput loginInput) + throws SFException { + if (loginInput.getOauthAccessToken() != null) { // Access Token was cached loginInput.setAuthenticator(AuthenticatorType.OAUTH.name()); loginInput.setToken(loginInput.getOauthAccessToken()); } else { // Access Token not cached - obtainOAuthAccessTokenAndUpdateInput(loginInput); + fetchOAuthAccessTokenAndUpdateInput(loginInput); } } - private static void obtainOAuthAccessTokenAndUpdateInput(SFLoginInput loginInput) + private static void fetchOAuthAccessTokenAndUpdateInput(SFLoginInput loginInput) throws SFException { OAuthAccessTokenProviderFactory accessTokenProviderFactory = new OAuthAccessTokenProviderFactory( @@ -387,9 +388,9 @@ private static void refreshOAuthAccessTokenAndUpdateInput(SFLoginInput loginInpu logger.debug( "Refreshing OAuth access token failed. Removing OAuth refresh token from cache and restarting OAuth flow...", e); - deleteOAuthRefreshTokenCache(loginInput.getHostFromServerUrl(), loginInput.getUserName()); + CredentialManager.deleteOAuthRefreshTokenCache(loginInput); loginInput.setAuthenticator(loginInput.getOriginAuthenticator()); - obtainOAuthAccessTokenAndUpdateInput(loginInput); + fetchOAuthAccessTokenAndUpdateInput(loginInput); } } @@ -860,9 +861,15 @@ static SFLoginOutput newSession( SnowflakeUtil.checkErrorAndThrowExceptionIncludingReauth(jsonNode); } + if (errorCode == Constants.OAUTH_ACCESS_TOKEN_INVALID_GS_CODE) { + logger.debug("OAuth Access Token Invalid: {}", errorCode); + loginInput.setOauthAccessToken(null); + CredentialManager.deleteOAuthAccessTokenCache(loginInput); + } + if (errorCode == Constants.OAUTH_ACCESS_TOKEN_EXPIRED_GS_CODE) { loginInput.setOauthAccessToken(null); - deleteOAuthAccessTokenCache(loginInput.getHostFromServerUrl(), loginInput.getUserName()); + CredentialManager.deleteOAuthAccessTokenCache(loginInput); logger.debug("OAuth Access Token Expired: {}", errorCode); SnowflakeUtil.checkErrorAndThrowExceptionIncludingReauth(jsonNode); @@ -1007,7 +1014,7 @@ static SFLoginOutput newSession( if (asBoolean(loginInput.getSessionParameters().get(CLIENT_STORE_TEMPORARY_CREDENTIAL))) { if (consentCacheIdToken) { - CredentialManager.writeIdToken(loginInput, ret); + CredentialManager.writeIdToken(loginInput, ret.getIdToken()); } if (loginInput.getOauthAccessToken() != null) { CredentialManager.writeOAuthAccessToken(loginInput); @@ -1018,7 +1025,7 @@ static SFLoginOutput newSession( } if (asBoolean(loginInput.getSessionParameters().get(CLIENT_REQUEST_MFA_TOKEN))) { - CredentialManager.writeMfaToken(loginInput, ret); + CredentialManager.writeMfaToken(loginInput, ret.getMfaToken()); } stopwatch.stop(); @@ -1065,14 +1072,6 @@ public static void deleteMfaTokenCache(String host, String user) { CredentialManager.deleteMfaTokenCache(host, user); } - private static void deleteOAuthAccessTokenCache(String host, String user) { - CredentialManager.deleteOAuthAccessTokenCache(host, user); - } - - private static void deleteOAuthRefreshTokenCache(String host, String user) { - CredentialManager.deleteOAuthRefreshTokenCache(host, user); - } - /** * Renew a session. * diff --git a/src/main/java/net/snowflake/client/core/auth/oauth/OAuthAccessTokenForRefreshTokenProvider.java b/src/main/java/net/snowflake/client/core/auth/oauth/OAuthAccessTokenForRefreshTokenProvider.java index 88a71571f..0d74bbbc6 100644 --- a/src/main/java/net/snowflake/client/core/auth/oauth/OAuthAccessTokenForRefreshTokenProvider.java +++ b/src/main/java/net/snowflake/client/core/auth/oauth/OAuthAccessTokenForRefreshTokenProvider.java @@ -28,7 +28,7 @@ public class OAuthAccessTokenForRefreshTokenProvider implements AccessTokenProvi private static final SFLogger logger = SFLoggerFactory.getLogger(OAuthClientCredentialsAccessTokenProvider.class); - private final ObjectMapper objectMapper = new ObjectMapper(); + private static final ObjectMapper objectMapper = new ObjectMapper(); @Override public TokenResponseDTO getAccessToken(SFLoginInput loginInput) throws SFException { diff --git a/src/main/java/net/snowflake/client/core/auth/oauth/OAuthAccessTokenProviderFactory.java b/src/main/java/net/snowflake/client/core/auth/oauth/OAuthAccessTokenProviderFactory.java index 5f9363cc5..3491b9593 100644 --- a/src/main/java/net/snowflake/client/core/auth/oauth/OAuthAccessTokenProviderFactory.java +++ b/src/main/java/net/snowflake/client/core/auth/oauth/OAuthAccessTokenProviderFactory.java @@ -44,7 +44,7 @@ public AccessTokenProvider createAccessTokenProvider( case OAUTH_AUTHORIZATION_CODE: assertContainsClientCredentials(loginInput, authenticatorType); return new OAuthAuthorizationCodeAccessTokenProvider( - browserHandler, browserAuthorizationTimeoutSeconds); + browserHandler, new RandomStateProvider(), browserAuthorizationTimeoutSeconds); case OAUTH_CLIENT_CREDENTIALS: assertContainsClientCredentials(loginInput, authenticatorType); AssertUtil.assertTrue( diff --git a/src/main/java/net/snowflake/client/core/auth/oauth/OAuthAuthorizationCodeAccessTokenProvider.java b/src/main/java/net/snowflake/client/core/auth/oauth/OAuthAuthorizationCodeAccessTokenProvider.java index 8f296b8b1..2b08303c6 100644 --- a/src/main/java/net/snowflake/client/core/auth/oauth/OAuthAuthorizationCodeAccessTokenProvider.java +++ b/src/main/java/net/snowflake/client/core/auth/oauth/OAuthAuthorizationCodeAccessTokenProvider.java @@ -52,14 +52,18 @@ public class OAuthAuthorizationCodeAccessTokenProvider implements AccessTokenPro private static final String DEFAULT_REDIRECT_HOST = "http://localhost:8001"; private static final String REDIRECT_URI_ENDPOINT = "/snowflake/oauth-redirect"; private static final String DEFAULT_REDIRECT_URI = DEFAULT_REDIRECT_HOST + REDIRECT_URI_ENDPOINT; + private static final ObjectMapper objectMapper = new ObjectMapper(); private final AuthExternalBrowserHandlers browserHandler; - private final ObjectMapper objectMapper = new ObjectMapper(); + private final StateProvider stateProvider; private final int browserAuthorizationTimeoutSeconds; public OAuthAuthorizationCodeAccessTokenProvider( - AuthExternalBrowserHandlers browserHandler, int browserAuthorizationTimeoutSeconds) { + AuthExternalBrowserHandlers browserHandler, + StateProvider stateProvider, + int browserAuthorizationTimeoutSeconds) { this.browserHandler = browserHandler; + this.stateProvider = stateProvider; this.browserAuthorizationTimeoutSeconds = browserAuthorizationTimeoutSeconds; } @@ -80,11 +84,13 @@ public TokenResponseDTO getAccessToken(SFLoginInput loginInput) throws SFExcepti private AuthorizationCode requestAuthorizationCode( SFLoginInput loginInput, CodeVerifier pkceVerifier) throws SFException, IOException { - AuthorizationRequest request = buildAuthorizationRequest(loginInput, pkceVerifier); + State state = new State(stateProvider.getState()); + AuthorizationRequest request = buildAuthorizationRequest(loginInput, pkceVerifier, state); SFOauthLoginInput oauthLoginInput = loginInput.getOauthLoginInput(); URI authorizeRequestURI = request.toURI(); HttpServer httpServer = createHttpServer(oauthLoginInput); - CompletableFuture codeFuture = setupRedirectURIServerForAuthorizationCode(httpServer); + CompletableFuture codeFuture = + setupRedirectURIServerForAuthorizationCode(httpServer, state); logger.debug( "Waiting for authorization code redirection to {}...", buildRedirectUri(oauthLoginInput)); return letUserAuthorize(authorizeRequestURI, codeFuture, httpServer); @@ -144,7 +150,7 @@ private AuthorizationCode letUserAuthorize( } private static CompletableFuture setupRedirectURIServerForAuthorizationCode( - HttpServer httpServer) { + HttpServer httpServer, State expectedState) { CompletableFuture accessTokenFuture = new CompletableFuture<>(); httpServer.createContext( REDIRECT_URI_ENDPOINT, @@ -152,21 +158,7 @@ private static CompletableFuture setupRedirectURIServerForAuthorizationC Map urlParams = URLEncodedUtils.parse(exchange.getRequestURI(), StandardCharsets.UTF_8).stream() .collect(Collectors.toMap(NameValuePair::getName, NameValuePair::getValue)); - if (urlParams.containsKey("error")) { - accessTokenFuture.completeExceptionally( - new SFException( - ErrorCode.OAUTH_AUTHORIZATION_CODE_FLOW_ERROR, - String.format( - "Error during authorization: %s, %s", - urlParams.get("error"), urlParams.get("error_description")))); - } else { - String authorizationCode = urlParams.get("code"); - if (!StringUtils.isNullOrEmpty(authorizationCode)) { - logger.debug("Received authorization code on redirect URI"); - accessTokenFuture.complete(authorizationCode); - } - } - String response = "Authorization completed successfully."; + String response = handleRedirectRequest(urlParams, accessTokenFuture, expectedState); exchange.sendResponseHeaders(200, response.length()); exchange.getResponseBody().write(response.getBytes(StandardCharsets.UTF_8)); exchange.getResponseBody().close(); @@ -176,6 +168,40 @@ private static CompletableFuture setupRedirectURIServerForAuthorizationC return accessTokenFuture; } + private static String handleRedirectRequest( + Map urlParams, + CompletableFuture accessTokenFuture, + State expectedState) { + String response; + if (urlParams.containsKey("error")) { + response = "Authorization error: " + urlParams.get("error"); + accessTokenFuture.completeExceptionally( + new SFException( + ErrorCode.OAUTH_AUTHORIZATION_CODE_FLOW_ERROR, + String.format( + "Error during authorization: %s, %s", + urlParams.get("error"), urlParams.get("error_description")))); + } else if (!expectedState.getValue().equals(urlParams.get("state"))) { + accessTokenFuture.completeExceptionally( + new SFException( + ErrorCode.OAUTH_AUTHORIZATION_CODE_FLOW_ERROR, + String.format( + "Invalid authorization request redirection state: %s, expected: %s", + urlParams.get("state"), expectedState.getValue()))); + response = "Authorization error: invalid authorization request redirection state"; + } else { + String authorizationCode = urlParams.get("code"); + if (!StringUtils.isNullOrEmpty(authorizationCode)) { + logger.debug("Received authorization code on redirect URI"); + response = "Authorization completed successfully."; + accessTokenFuture.complete(authorizationCode); + } else { + response = "Authorization error: authorization code has not been returned to the driver."; + } + } + return response; + } + private static HttpServer createHttpServer(SFOauthLoginInput loginInput) throws IOException { URI redirectUri = buildRedirectUri(loginInput); return HttpServer.create( @@ -183,11 +209,10 @@ private static HttpServer createHttpServer(SFOauthLoginInput loginInput) throws } private static AuthorizationRequest buildAuthorizationRequest( - SFLoginInput loginInput, CodeVerifier pkceVerifier) { + SFLoginInput loginInput, CodeVerifier pkceVerifier, State state) { SFOauthLoginInput oauthLoginInput = loginInput.getOauthLoginInput(); ClientID clientID = new ClientID(oauthLoginInput.getClientId()); URI callback = buildRedirectUri(oauthLoginInput); - State state = new State(256); String scope = OAuthUtil.getScope(loginInput.getOauthLoginInput(), loginInput.getRole()); return new AuthorizationRequest.Builder(new ResponseType(ResponseType.Value.CODE), clientID) .scope(new Scope(scope)) diff --git a/src/main/java/net/snowflake/client/core/auth/oauth/OAuthClientCredentialsAccessTokenProvider.java b/src/main/java/net/snowflake/client/core/auth/oauth/OAuthClientCredentialsAccessTokenProvider.java index 0f53ef13d..de931b850 100644 --- a/src/main/java/net/snowflake/client/core/auth/oauth/OAuthClientCredentialsAccessTokenProvider.java +++ b/src/main/java/net/snowflake/client/core/auth/oauth/OAuthClientCredentialsAccessTokenProvider.java @@ -27,7 +27,7 @@ public class OAuthClientCredentialsAccessTokenProvider implements AccessTokenPro private static final SFLogger logger = SFLoggerFactory.getLogger(OAuthClientCredentialsAccessTokenProvider.class); - private final ObjectMapper objectMapper = new ObjectMapper(); + private static final ObjectMapper objectMapper = new ObjectMapper(); @Override public TokenResponseDTO getAccessToken(SFLoginInput loginInput) throws SFException { diff --git a/src/main/java/net/snowflake/client/core/auth/oauth/RandomStateProvider.java b/src/main/java/net/snowflake/client/core/auth/oauth/RandomStateProvider.java new file mode 100644 index 000000000..a1ddee884 --- /dev/null +++ b/src/main/java/net/snowflake/client/core/auth/oauth/RandomStateProvider.java @@ -0,0 +1,16 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +package net.snowflake.client.core.auth.oauth; + +import com.nimbusds.oauth2.sdk.id.State; +import net.snowflake.client.core.SnowflakeJdbcInternalApi; + +@SnowflakeJdbcInternalApi +public class RandomStateProvider implements StateProvider { + @Override + public String getState() { + return new State(256).getValue(); + } +} diff --git a/src/main/java/net/snowflake/client/core/auth/oauth/StateProvider.java b/src/main/java/net/snowflake/client/core/auth/oauth/StateProvider.java new file mode 100644 index 000000000..6dcad2548 --- /dev/null +++ b/src/main/java/net/snowflake/client/core/auth/oauth/StateProvider.java @@ -0,0 +1,12 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +package net.snowflake.client.core.auth.oauth; + +import net.snowflake.client.core.SnowflakeJdbcInternalApi; + +@SnowflakeJdbcInternalApi +public interface StateProvider { + String getState(); +} diff --git a/src/test/java/net/snowflake/client/AbstractDriverIT.java b/src/test/java/net/snowflake/client/AbstractDriverIT.java index cbb536926..0293aaddd 100644 --- a/src/test/java/net/snowflake/client/AbstractDriverIT.java +++ b/src/test/java/net/snowflake/client/AbstractDriverIT.java @@ -325,8 +325,6 @@ public static Connection getConnection( properties.put("internal", Boolean.TRUE.toString()); // TODO: do we need this? properties.put("insecureMode", false); // use OCSP for all tests. - properties.put("authenticator", AuthenticatorType.OAUTH_AUTHORIZATION_CODE.name()); - if (injectSocketTimeout > 0) { properties.put("injectSocketTimeout", String.valueOf(injectSocketTimeout)); } diff --git a/src/test/java/net/snowflake/client/core/CredentialManagerTest.java b/src/test/java/net/snowflake/client/core/CredentialManagerTest.java new file mode 100644 index 000000000..c622692a0 --- /dev/null +++ b/src/test/java/net/snowflake/client/core/CredentialManagerTest.java @@ -0,0 +1,197 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +package net.snowflake.client.core; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +class CredentialManagerTest { + + public static final String SNOWFLAKE_HOST = "some-account.us-west-2.aws.snowflakecomputing.com"; + public static final String EXTERNAL_OAUTH_HOST = "some-external-oauth-host.com"; + public static final String SOME_ACCESS_TOKEN = "some-oauth-access-token"; + public static final String SOME_REFRESH_TOKEN = "some-refresh-token"; + public static final String SOME_ID_TOKEN_FROM_CACHE = "some-id-token"; + public static final String SOME_MFA_TOKEN_FROM_CACHE = "some-mfa-token"; + public static final String SOME_USER = "some-user"; + + public static final String ACCESS_TOKEN_FROM_CACHE = "access-token-from-cache"; + public static final String REFRESH_TOKEN_FROM_CACHE = "refresh-token-from-cache"; + public static final String EXTERNAL_ACCESS_TOKEN_FROM_CACHE = "external-access-token-from-cache"; + public static final String EXTERNAL_REFRESH_TOKEN_FROM_CACHE = + "external-refresh-token-from-cache"; + + private static final SecureStorageManager mockSecureStorageManager = + mock(SecureStorageManager.class); + + @BeforeAll + public static void init() { + CredentialManager.injectSecureStorageManager(mockSecureStorageManager); + } + + @Test + public void shouldCreateHostBasedOnExternalIdpUrl() throws SFException { + SFLoginInput loginInput = createLoginInputWithExternalOAuth(); + String host = CredentialManager.getHostForOAuthCacheKey(loginInput); + assertEquals(EXTERNAL_OAUTH_HOST, host); + } + + @Test + public void shouldCreateHostBasedOnSnowflakeServerUrl() throws SFException { + SFLoginInput loginInput = createLoginInputWithSnowflakeServer(); + String host = CredentialManager.getHostForOAuthCacheKey(loginInput); + assertEquals(SNOWFLAKE_HOST, host); + } + + @Test + public void shouldProperlyWriteTokensToCache() throws SFException { + SFLoginInput loginInputSnowflakeOAuth = createLoginInputWithSnowflakeServer(); + CredentialManager.writeIdToken(loginInputSnowflakeOAuth, SOME_ID_TOKEN_FROM_CACHE); + verify(mockSecureStorageManager, times(1)) + .setCredential( + SNOWFLAKE_HOST, + SOME_USER, + CachedCredentialType.ID_TOKEN.getValue(), + SOME_ID_TOKEN_FROM_CACHE); + CredentialManager.writeMfaToken(loginInputSnowflakeOAuth, SOME_MFA_TOKEN_FROM_CACHE); + verify(mockSecureStorageManager, times(1)) + .setCredential( + SNOWFLAKE_HOST, + SOME_USER, + CachedCredentialType.MFA_TOKEN.getValue(), + SOME_MFA_TOKEN_FROM_CACHE); + + CredentialManager.writeOAuthAccessToken(loginInputSnowflakeOAuth); + verify(mockSecureStorageManager, times(1)) + .setCredential( + SNOWFLAKE_HOST, + SOME_USER, + CachedCredentialType.OAUTH_ACCESS_TOKEN.getValue(), + SOME_ACCESS_TOKEN); + CredentialManager.writeOAuthRefreshToken(loginInputSnowflakeOAuth); + verify(mockSecureStorageManager, times(1)) + .setCredential( + SNOWFLAKE_HOST, + SOME_USER, + CachedCredentialType.OAUTH_REFRESH_TOKEN.getValue(), + SOME_REFRESH_TOKEN); + + SFLoginInput loginInputExternalOAuth = createLoginInputWithExternalOAuth(); + CredentialManager.writeOAuthAccessToken(loginInputExternalOAuth); + verify(mockSecureStorageManager, times(1)) + .setCredential( + EXTERNAL_OAUTH_HOST, + SOME_USER, + CachedCredentialType.OAUTH_ACCESS_TOKEN.getValue(), + SOME_ACCESS_TOKEN); + CredentialManager.writeOAuthRefreshToken(loginInputExternalOAuth); + verify(mockSecureStorageManager, times(1)) + .setCredential( + EXTERNAL_OAUTH_HOST, + SOME_USER, + CachedCredentialType.OAUTH_REFRESH_TOKEN.getValue(), + SOME_REFRESH_TOKEN); + } + + @Test + public void shouldProperlyDeleteTokensFromCache() throws SFException { + SFLoginInput loginInputSnowflakeOAuth = createLoginInputWithSnowflakeServer(); + CredentialManager.deleteIdTokenCache( + loginInputSnowflakeOAuth.getHostFromServerUrl(), loginInputSnowflakeOAuth.getUserName()); + verify(mockSecureStorageManager, times(1)) + .deleteCredential(SNOWFLAKE_HOST, SOME_USER, CachedCredentialType.ID_TOKEN.getValue()); + CredentialManager.deleteMfaTokenCache( + loginInputSnowflakeOAuth.getHostFromServerUrl(), loginInputSnowflakeOAuth.getUserName()); + verify(mockSecureStorageManager, times(1)) + .deleteCredential(SNOWFLAKE_HOST, SOME_USER, CachedCredentialType.MFA_TOKEN.getValue()); + + CredentialManager.deleteOAuthAccessTokenCache(loginInputSnowflakeOAuth); + verify(mockSecureStorageManager, times(1)) + .deleteCredential( + SNOWFLAKE_HOST, SOME_USER, CachedCredentialType.OAUTH_ACCESS_TOKEN.getValue()); + CredentialManager.deleteOAuthRefreshTokenCache(loginInputSnowflakeOAuth); + verify(mockSecureStorageManager, times(1)) + .deleteCredential( + SNOWFLAKE_HOST, SOME_USER, CachedCredentialType.OAUTH_REFRESH_TOKEN.getValue()); + + SFLoginInput loginInputExternalOAuth = createLoginInputWithExternalOAuth(); + CredentialManager.deleteOAuthAccessTokenCache(loginInputExternalOAuth); + verify(mockSecureStorageManager, times(1)) + .deleteCredential( + EXTERNAL_OAUTH_HOST, SOME_USER, CachedCredentialType.OAUTH_ACCESS_TOKEN.getValue()); + CredentialManager.deleteOAuthRefreshTokenCache(loginInputExternalOAuth); + verify(mockSecureStorageManager, times(1)) + .deleteCredential( + EXTERNAL_OAUTH_HOST, SOME_USER, CachedCredentialType.OAUTH_REFRESH_TOKEN.getValue()); + } + + @Test + public void shouldProperlyUpdateInputWithTokensFromCache() throws SFException { + SFLoginInput loginInputSnowflakeOAuth = createLoginInputWithSnowflakeServer(); + when(mockSecureStorageManager.getCredential( + SNOWFLAKE_HOST, SOME_USER, CachedCredentialType.ID_TOKEN.getValue())) + .thenReturn(SOME_ID_TOKEN_FROM_CACHE); + CredentialManager.fillCachedIdToken(loginInputSnowflakeOAuth); + when(mockSecureStorageManager.getCredential( + SNOWFLAKE_HOST, SOME_USER, CachedCredentialType.MFA_TOKEN.getValue())) + .thenReturn(SOME_MFA_TOKEN_FROM_CACHE); + CredentialManager.fillCachedMfaToken(loginInputSnowflakeOAuth); + assertEquals(SOME_ID_TOKEN_FROM_CACHE, loginInputSnowflakeOAuth.getIdToken()); + assertEquals(SOME_MFA_TOKEN_FROM_CACHE, loginInputSnowflakeOAuth.getMfaToken()); + + when(mockSecureStorageManager.getCredential( + SNOWFLAKE_HOST, SOME_USER, CachedCredentialType.OAUTH_ACCESS_TOKEN.getValue())) + .thenReturn(ACCESS_TOKEN_FROM_CACHE); + CredentialManager.fillCachedOAuthAccessToken(loginInputSnowflakeOAuth); + when(mockSecureStorageManager.getCredential( + SNOWFLAKE_HOST, SOME_USER, CachedCredentialType.OAUTH_REFRESH_TOKEN.getValue())) + .thenReturn(REFRESH_TOKEN_FROM_CACHE); + CredentialManager.fillCachedOAuthRefreshToken(loginInputSnowflakeOAuth); + assertEquals(ACCESS_TOKEN_FROM_CACHE, loginInputSnowflakeOAuth.getOauthAccessToken()); + assertEquals(REFRESH_TOKEN_FROM_CACHE, loginInputSnowflakeOAuth.getOauthRefreshToken()); + + SFLoginInput loginInputExternalOAuth = createLoginInputWithExternalOAuth(); + when(mockSecureStorageManager.getCredential( + EXTERNAL_OAUTH_HOST, SOME_USER, CachedCredentialType.OAUTH_ACCESS_TOKEN.getValue())) + .thenReturn(EXTERNAL_ACCESS_TOKEN_FROM_CACHE); + CredentialManager.fillCachedOAuthAccessToken(loginInputExternalOAuth); + when(mockSecureStorageManager.getCredential( + EXTERNAL_OAUTH_HOST, SOME_USER, CachedCredentialType.OAUTH_REFRESH_TOKEN.getValue())) + .thenReturn(EXTERNAL_REFRESH_TOKEN_FROM_CACHE); + CredentialManager.fillCachedOAuthRefreshToken(loginInputExternalOAuth); + assertEquals(EXTERNAL_ACCESS_TOKEN_FROM_CACHE, loginInputExternalOAuth.getOauthAccessToken()); + assertEquals(EXTERNAL_REFRESH_TOKEN_FROM_CACHE, loginInputExternalOAuth.getOauthRefreshToken()); + } + + private SFLoginInput createLoginInputWithExternalOAuth() { + SFLoginInput loginInput = createGenericLoginInput(); + loginInput.setOauthLoginInput( + new SFOauthLoginInput( + null, null, null, null, "https://some-external-oauth-host.com/oauth/token", null)); + return loginInput; + } + + private SFLoginInput createLoginInputWithSnowflakeServer() { + SFLoginInput loginInput = createGenericLoginInput(); + loginInput.setOauthLoginInput(new SFOauthLoginInput(null, null, null, null, null, null)); + loginInput.setServerUrl("https://some-account.us-west-2.aws.snowflakecomputing.com:443/"); + + return loginInput; + } + + private SFLoginInput createGenericLoginInput() { + SFLoginInput loginInput = new SFLoginInput(); + loginInput.setOauthAccessToken(SOME_ACCESS_TOKEN); + loginInput.setOauthRefreshToken(SOME_REFRESH_TOKEN); + loginInput.setUserName(SOME_USER); + return loginInput; + } +} diff --git a/src/test/java/net/snowflake/client/core/OAuthAuthorizationCodeFlowLatestIT.java b/src/test/java/net/snowflake/client/core/OAuthAuthorizationCodeFlowLatestIT.java index 86e9eb84d..d45eafa4b 100644 --- a/src/test/java/net/snowflake/client/core/OAuthAuthorizationCodeFlowLatestIT.java +++ b/src/test/java/net/snowflake/client/core/OAuthAuthorizationCodeFlowLatestIT.java @@ -12,6 +12,7 @@ import net.snowflake.client.category.TestTags; import net.snowflake.client.core.auth.oauth.AccessTokenProvider; import net.snowflake.client.core.auth.oauth.OAuthAuthorizationCodeAccessTokenProvider; +import net.snowflake.client.core.auth.oauth.StateProvider; import net.snowflake.client.core.auth.oauth.TokenResponseDTO; import net.snowflake.client.jdbc.BaseWiremockTest; import org.apache.http.HttpResponse; @@ -35,6 +36,8 @@ public class OAuthAuthorizationCodeFlowLatestIT extends BaseWiremockTest { SCENARIOS_BASE_DIR + "/browser_timeout_authorization_error.json"; private static final String INVALID_SCOPE_SCENARIO_MAPPING = SCENARIOS_BASE_DIR + "/invalid_scope_error.json"; + private static final String INVALID_STATE_SCENARIO_MAPPING = + SCENARIOS_BASE_DIR + "/invalid_state_error.json"; private static final String TOKEN_REQUEST_ERROR_SCENARIO_MAPPING = SCENARIOS_BASE_DIR + "/token_request_error.json"; private static final String CUSTOM_URLS_SCENARIO_MAPPINGS = @@ -46,14 +49,16 @@ public class OAuthAuthorizationCodeFlowLatestIT extends BaseWiremockTest { private final AuthExternalBrowserHandlers wiremockProxyRequestBrowserHandler = new WiremockProxyRequestBrowserHandler(); + private final AccessTokenProvider provider = + new OAuthAuthorizationCodeAccessTokenProvider( + wiremockProxyRequestBrowserHandler, new MockStateProvider(), 30); + @Test public void successfulFlowScenario() throws SFException { importMappingFromResources(SUCCESSFUL_FLOW_SCENARIO_MAPPINGS); SFLoginInput loginInput = createLoginInputStub("http://localhost:8009/snowflake/oauth-redirect", null, null); - AccessTokenProvider provider = - new OAuthAuthorizationCodeAccessTokenProvider(wiremockProxyRequestBrowserHandler, 30); TokenResponseDTO tokenResponse = provider.getAccessToken(loginInput); String accessToken = tokenResponse.getAccessToken(); @@ -70,8 +75,6 @@ public void customUrlsScenario() throws SFException { String.format("http://%s:%d/authorization", WIREMOCK_HOST, wiremockHttpPort), String.format("http://%s:%d/tokenrequest", WIREMOCK_HOST, wiremockHttpPort)); - AccessTokenProvider provider = - new OAuthAuthorizationCodeAccessTokenProvider(wiremockProxyRequestBrowserHandler, 30); TokenResponseDTO tokenResponse = provider.getAccessToken(loginInput); String accessToken = tokenResponse.getAccessToken(); @@ -86,7 +89,8 @@ public void browserTimeoutFlowScenario() { createLoginInputStub("http://localhost:8004/snowflake/oauth-redirect", null, null); AccessTokenProvider provider = - new OAuthAuthorizationCodeAccessTokenProvider(wiremockProxyRequestBrowserHandler, 1); + new OAuthAuthorizationCodeAccessTokenProvider( + wiremockProxyRequestBrowserHandler, new MockStateProvider(), 1); SFException e = Assertions.assertThrows(SFException.class, () -> provider.getAccessToken(loginInput)); Assertions.assertTrue( @@ -100,8 +104,6 @@ public void invalidScopeFlowScenario() { importMappingFromResources(INVALID_SCOPE_SCENARIO_MAPPING); SFLoginInput loginInput = createLoginInputStub("http://localhost:8002/snowflake/oauth-redirect", null, null); - AccessTokenProvider provider = - new OAuthAuthorizationCodeAccessTokenProvider(wiremockProxyRequestBrowserHandler, 30); SFException e = Assertions.assertThrows(SFException.class, () -> provider.getAccessToken(loginInput)); Assertions.assertTrue( @@ -110,14 +112,25 @@ public void invalidScopeFlowScenario() { "Error during authorization: invalid_scope, One or more scopes are not configured for the authorization server resource.")); } + @Test + public void invalidStateFlowScenario() { + importMappingFromResources(INVALID_STATE_SCENARIO_MAPPING); + SFLoginInput loginInput = + createLoginInputStub("http://localhost:8010/snowflake/oauth-redirect", null, null); + SFException e = + Assertions.assertThrows(SFException.class, () -> provider.getAccessToken(loginInput)); + Assertions.assertTrue( + e.getMessage() + .contains( + "Error during OAuth Authorization Code authentication: Invalid authorization request redirection state: invalidstate, expected: abc123")); + } + @Test public void tokenRequestErrorFlowScenario() { importMappingFromResources(TOKEN_REQUEST_ERROR_SCENARIO_MAPPING); SFLoginInput loginInput = createLoginInputStub("http://localhost:8003/snowflake/oauth-redirect", null, null); - AccessTokenProvider provider = - new OAuthAuthorizationCodeAccessTokenProvider(wiremockProxyRequestBrowserHandler, 30); SFException e = Assertions.assertThrows(SFException.class, () -> provider.getAccessToken(loginInput)); Assertions.assertTrue( @@ -168,4 +181,12 @@ public void output(String msg) { // do nothing } } + + static class MockStateProvider implements StateProvider { + + @Override + public String getState() { + return "abc123"; + } + } } diff --git a/src/test/java/net/snowflake/client/core/OAuthTokenCacheLatestIT.java b/src/test/java/net/snowflake/client/core/OAuthTokenCacheLatestIT.java index d91fd0ac3..27ab424f6 100644 --- a/src/test/java/net/snowflake/client/core/OAuthTokenCacheLatestIT.java +++ b/src/test/java/net/snowflake/client/core/OAuthTokenCacheLatestIT.java @@ -9,14 +9,17 @@ import java.time.Duration; import java.util.Collections; import java.util.HashMap; +import net.snowflake.client.category.TestTags; import net.snowflake.client.core.auth.AuthenticatorType; import net.snowflake.client.jdbc.BaseWiremockTest; import net.snowflake.client.jdbc.SnowflakeSQLException; +import org.junit.jupiter.api.Tag; import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; import org.mockito.MockedStatic; import org.mockito.stubbing.Answer; +@Tag(TestTags.CORE) public class OAuthTokenCacheLatestIT extends BaseWiremockTest { private static final String SCENARIOS_BASE_DIR = MAPPINGS_BASE_DIR + "/oauth/token_caching"; @@ -73,15 +76,9 @@ public void shouldRefreshExpiredAccessTokenAndConnectSuccessfully() SFLoginOutput loginOutput = SessionUtil.openSession(loginInput, new HashMap<>(), "INFO"); credentialManagerMockedStatic.verify( - () -> - CredentialManager.deleteOAuthAccessTokenCache( - loginInput.getHostFromServerUrl(), loginInput.getUserName()), - times(1)); + () -> CredentialManager.deleteOAuthAccessTokenCache(loginInput), times(1)); credentialManagerMockedStatic.verify( - () -> - CredentialManager.deleteOAuthRefreshTokenCache( - loginInput.getHostFromServerUrl(), loginInput.getUserName()), - never()); + () -> CredentialManager.deleteOAuthRefreshTokenCache(loginInput), never()); assertEquals("new-refreshed-access-token-123", loginOutput.getOauthAccessToken()); captureAndAssertSavedTokenValues( @@ -105,10 +102,7 @@ public void shouldCacheRefreshedAccessTokenAndNewRefreshToken() SFLoginOutput loginOutput = SessionUtil.openSession(loginInput, new HashMap<>(), "INFO"); credentialManagerMockedStatic.verify( - () -> - CredentialManager.deleteOAuthAccessTokenCache( - loginInput.getHostFromServerUrl(), loginInput.getUserName()), - times(1)); + () -> CredentialManager.deleteOAuthAccessTokenCache(loginInput), times(1)); assertEquals("new-refreshed-access-token-123", loginOutput.getOauthAccessToken()); captureAndAssertSavedTokenValues( credentialManagerMockedStatic, "new-refreshed-access-token-123", "new-refresh-token-123"); @@ -127,15 +121,9 @@ public void shouldRestartFullFlowOnAccessTokenExpirationAndErrorWhenRefreshing() SFLoginOutput loginOutput = SessionUtil.openSession(loginInput, new HashMap<>(), "INFO"); credentialManagerMockedStatic.verify( - () -> - CredentialManager.deleteOAuthAccessTokenCache( - loginInput.getHostFromServerUrl(), loginInput.getUserName()), - times(1)); + () -> CredentialManager.deleteOAuthAccessTokenCache(loginInput), times(1)); credentialManagerMockedStatic.verify( - () -> - CredentialManager.deleteOAuthRefreshTokenCache( - loginInput.getHostFromServerUrl(), loginInput.getUserName()), - times(1)); + () -> CredentialManager.deleteOAuthRefreshTokenCache(loginInput), times(1)); assertEquals("newly-obtained-access-token-123", loginOutput.getOauthAccessToken()); captureAndAssertSavedTokenValues( credentialManagerMockedStatic, @@ -156,15 +144,9 @@ public void shouldRestartFullFlowOnAccessTokenExpirationAndNoRefreshToken() SFLoginOutput loginOutput = SessionUtil.openSession(loginInput, new HashMap<>(), "INFO"); credentialManagerMockedStatic.verify( - () -> - CredentialManager.deleteOAuthAccessTokenCache( - loginInput.getHostFromServerUrl(), loginInput.getUserName()), - times(1)); + () -> CredentialManager.deleteOAuthAccessTokenCache(loginInput), times(1)); credentialManagerMockedStatic.verify( - () -> - CredentialManager.deleteOAuthRefreshTokenCache( - loginInput.getHostFromServerUrl(), loginInput.getUserName()), - never()); + () -> CredentialManager.deleteOAuthRefreshTokenCache(loginInput), never()); assertEquals("newly-obtained-access-token-123", loginOutput.getOauthAccessToken()); captureAndAssertSavedTokenValues( credentialManagerMockedStatic, "newly-obtained-access-token-123", null); diff --git a/src/test/java/net/snowflake/client/core/SnowflakeMFACacheTest.java b/src/test/java/net/snowflake/client/core/SnowflakeMFACacheTest.java index 0524ab6b8..88451aa0e 100644 --- a/src/test/java/net/snowflake/client/core/SnowflakeMFACacheTest.java +++ b/src/test/java/net/snowflake/client/core/SnowflakeMFACacheTest.java @@ -317,7 +317,7 @@ private void testUnavailableLSSWindowsHelper() throws SQLException { try { SecureStorageWindowsManager.Advapi32LibManager.setInstance(new MockUnavailableAdvapi32Lib()); SecureStorageWindowsManager manager = SecureStorageWindowsManager.builder(); - CredentialManager.getInstance().injectSecureStorageManager(manager); + CredentialManager.injectSecureStorageManager(manager); unavailableLSSWindowsTestBody(); } finally { SecureStorageWindowsManager.Advapi32LibManager.resetInstance(); @@ -329,7 +329,7 @@ public void testUnavailableLocalSecureStorage() throws SQLException { try { testUnavailableLSSWindowsHelper(); } finally { - CredentialManager.getInstance().resetSecureStorageManager(); + CredentialManager.resetSecureStorageManager(); } } diff --git a/src/test/resources/wiremock/mappings/oauth/authorization_code/external_idp_custom_urls.json b/src/test/resources/wiremock/mappings/oauth/authorization_code/external_idp_custom_urls.json index 8b69dae51..c03129042 100644 --- a/src/test/resources/wiremock/mappings/oauth/authorization_code/external_idp_custom_urls.json +++ b/src/test/resources/wiremock/mappings/oauth/authorization_code/external_idp_custom_urls.json @@ -34,7 +34,7 @@ "response": { "status": 302, "headers": { - "Location": "http://localhost:8007/snowflake/oauth-redirect?code=123" + "Location": "http://localhost:8007/snowflake/oauth-redirect?code=123&state=abc123" } } }, diff --git a/src/test/resources/wiremock/mappings/oauth/authorization_code/invalid_state_error.json b/src/test/resources/wiremock/mappings/oauth/authorization_code/invalid_state_error.json new file mode 100644 index 000000000..fb19c1cba --- /dev/null +++ b/src/test/resources/wiremock/mappings/oauth/authorization_code/invalid_state_error.json @@ -0,0 +1,17 @@ +{ + "mappings": [ + { + "scenarioName": "Invalid scope authorization error", + "request": { + "urlPathPattern": "/oauth/authorize.*", + "method": "GET" + }, + "response": { + "status": 302, + "headers": { + "Location": "http://localhost:8010/snowflake/oauth-redirect?code=123&state=invalidstate" + } + } + } + ] +} \ No newline at end of file diff --git a/src/test/resources/wiremock/mappings/oauth/authorization_code/successful_flow.json b/src/test/resources/wiremock/mappings/oauth/authorization_code/successful_flow.json index ec7d938ac..f49b224b4 100644 --- a/src/test/resources/wiremock/mappings/oauth/authorization_code/successful_flow.json +++ b/src/test/resources/wiremock/mappings/oauth/authorization_code/successful_flow.json @@ -34,7 +34,7 @@ "response": { "status": 302, "headers": { - "Location": "http://localhost:8009/snowflake/oauth-redirect?code=123" + "Location": "http://localhost:8009/snowflake/oauth-redirect?code=123&state=abc123" } } }, diff --git a/src/test/resources/wiremock/mappings/oauth/authorization_code/token_request_error.json b/src/test/resources/wiremock/mappings/oauth/authorization_code/token_request_error.json index 5c5e73ae1..792e1e363 100644 --- a/src/test/resources/wiremock/mappings/oauth/authorization_code/token_request_error.json +++ b/src/test/resources/wiremock/mappings/oauth/authorization_code/token_request_error.json @@ -11,7 +11,7 @@ "response": { "status": 302, "headers": { - "Location": "http://localhost:8003/snowflake/oauth-redirect?code=123" + "Location": "http://localhost:8003/snowflake/oauth-redirect?code=123&state=abc123" } } },