From cbb581027f963830016abf07ca0bb7fe29a645a2 Mon Sep 17 00:00:00 2001 From: Dawid Heyman Date: Fri, 6 Dec 2024 16:11:33 +0100 Subject: [PATCH 1/7] Initial implementation of client credentials flow --- .../net/snowflake/client/core/AssertUtil.java | 2 +- .../net/snowflake/client/core/HttpUtil.java | 7 ++ .../net/snowflake/client/core/SFSession.java | 4 +- .../snowflake/client/core/SessionUtil.java | 26 +++---- .../client/core/auth/AuthenticatorType.java | 7 +- ...Provider.java => AccessTokenProvider.java} | 2 +- .../oauth/AccessTokenProviderFactory.java | 67 +++++++++++++++++ ...AuthorizationCodeAccessTokenProvider.java} | 58 ++++----------- ...hClientCredentialsAccessTokenProvider.java | 72 +++++++++++++++++++ .../client/core/auth/oauth/OAuthUtil.java | 46 ++++++++++++ .../net/snowflake/client/jdbc/ErrorCode.java | 3 +- .../jdbc/jdbc_error_messages.properties | 4 +- .../snowflake/client/AbstractDriverIT.java | 3 + .../OauthAuthorizationCodeFlowLatestIT.java | 24 +++---- 14 files changed, 250 insertions(+), 75 deletions(-) rename src/main/java/net/snowflake/client/core/auth/oauth/{OauthAccessTokenProvider.java => AccessTokenProvider.java} (87%) create mode 100644 src/main/java/net/snowflake/client/core/auth/oauth/AccessTokenProviderFactory.java rename src/main/java/net/snowflake/client/core/auth/oauth/{AuthorizationCodeFlowAccessTokenProvider.java => OAuthAuthorizationCodeAccessTokenProvider.java} (79%) create mode 100644 src/main/java/net/snowflake/client/core/auth/oauth/OAuthClientCredentialsAccessTokenProvider.java create mode 100644 src/main/java/net/snowflake/client/core/auth/oauth/OAuthUtil.java diff --git a/src/main/java/net/snowflake/client/core/AssertUtil.java b/src/main/java/net/snowflake/client/core/AssertUtil.java index c24139fa9..91237cf3b 100644 --- a/src/main/java/net/snowflake/client/core/AssertUtil.java +++ b/src/main/java/net/snowflake/client/core/AssertUtil.java @@ -16,7 +16,7 @@ public class AssertUtil { * @param internalErrorMesg The error message to display if condition is false * @throws SFException Will be thrown if condition is false */ - static void assertTrue(boolean condition, String internalErrorMesg) throws SFException { + public static void assertTrue(boolean condition, String internalErrorMesg) throws SFException { if (!condition) { throw new SFException(ErrorCode.INTERNAL_ERROR, internalErrorMesg); } diff --git a/src/main/java/net/snowflake/client/core/HttpUtil.java b/src/main/java/net/snowflake/client/core/HttpUtil.java index 23b83df09..7a6b92ada 100644 --- a/src/main/java/net/snowflake/client/core/HttpUtil.java +++ b/src/main/java/net/snowflake/client/core/HttpUtil.java @@ -836,12 +836,19 @@ private static String executeRequestInternal( stopwatch.stop(); } + writer = new StringWriter(); + try (InputStream ins = response.getEntity().getContent()) { + IOUtils.copy(ins, writer, "UTF-8"); + } + theString = writer.toString(); + if (response == null || response.getStatusLine().getStatusCode() != 200) { logger.error("Error executing request: {}", requestInfoScrubbed); SnowflakeUtil.logResponseDetails(response, logger); if (response != null) { + EntityUtils.consume(response.getEntity()); } diff --git a/src/main/java/net/snowflake/client/core/SFSession.java b/src/main/java/net/snowflake/client/core/SFSession.java index c892dfdc7..b63d1c6dd 100644 --- a/src/main/java/net/snowflake/client/core/SFSession.java +++ b/src/main/java/net/snowflake/client/core/SFSession.java @@ -645,9 +645,9 @@ public synchronized void open() throws SFException, SnowflakeSQLException { (String) connectionPropertiesMap.get(SFSessionProperty.CLIENT_ID), (String) connectionPropertiesMap.get(SFSessionProperty.CLIENT_SECRET), (String) connectionPropertiesMap.get(SFSessionProperty.OAUTH_REDIRECT_URI), - (String) connectionPropertiesMap.get(SFSessionProperty.OAUTH_SCOPE), (String) connectionPropertiesMap.get(SFSessionProperty.EXTERNAL_AUTHORIZATION_URL), - (String) connectionPropertiesMap.get(SFSessionProperty.EXTERNAL_TOKEN_REQUEST_URL)); + (String) connectionPropertiesMap.get(SFSessionProperty.EXTERNAL_TOKEN_REQUEST_URL), + (String) connectionPropertiesMap.get(SFSessionProperty.OAUTH_SCOPE)); loginInput .setServerUrl((String) connectionPropertiesMap.get(SFSessionProperty.SERVER_URL)) diff --git a/src/main/java/net/snowflake/client/core/SessionUtil.java b/src/main/java/net/snowflake/client/core/SessionUtil.java index 1a26d1d86..dafbdf7fd 100644 --- a/src/main/java/net/snowflake/client/core/SessionUtil.java +++ b/src/main/java/net/snowflake/client/core/SessionUtil.java @@ -28,8 +28,8 @@ import net.snowflake.client.core.auth.AuthenticatorType; import net.snowflake.client.core.auth.ClientAuthnDTO; import net.snowflake.client.core.auth.ClientAuthnParameter; -import net.snowflake.client.core.auth.oauth.AuthorizationCodeFlowAccessTokenProvider; -import net.snowflake.client.core.auth.oauth.OauthAccessTokenProvider; +import net.snowflake.client.core.auth.oauth.AccessTokenProvider; +import net.snowflake.client.core.auth.oauth.AccessTokenProviderFactory; import net.snowflake.client.jdbc.ErrorCode; import net.snowflake.client.jdbc.SnowflakeDriver; import net.snowflake.client.jdbc.SnowflakeReauthenticationRequest; @@ -225,6 +225,11 @@ private static AuthenticatorType getAuthenticator(SFLoginInput loginInput) { .equalsIgnoreCase(AuthenticatorType.OAUTH_AUTHORIZATION_CODE.name())) { // OAuth authorization code flow authentication return AuthenticatorType.OAUTH_AUTHORIZATION_CODE; + } else if (loginInput + .getAuthenticator() + .equalsIgnoreCase(AuthenticatorType.OAUTH_CLIENT_CREDENTIALS.name())) { + // OAuth authorization code flow authentication + return AuthenticatorType.OAUTH_CLIENT_CREDENTIALS; } else if (loginInput.getAuthenticator().equalsIgnoreCase(AuthenticatorType.OAUTH.name())) { // OAuth access code Authentication return AuthenticatorType.OAUTH; @@ -273,17 +278,14 @@ static SFLoginOutput openSession( AssertUtil.assertTrue( loginInput.getLoginTimeout() >= 0, "negative login timeout for opening session"); - if (getAuthenticator(loginInput).equals(AuthenticatorType.OAUTH_AUTHORIZATION_CODE)) { - AssertUtil.assertTrue( - loginInput.getOauthLoginInput().getClientId() != null, - "passing clientId is required for OAUTH_AUTHORIZATION_CODE_FLOW authentication"); - AssertUtil.assertTrue( - loginInput.getOauthLoginInput().getClientSecret() != null, - "passing clientSecret is required for OAUTH_AUTHORIZATION_CODE_FLOW authentication"); - OauthAccessTokenProvider accessTokenProvider = - new AuthorizationCodeFlowAccessTokenProvider( + if (AccessTokenProviderFactory.isEligible(getAuthenticator(loginInput))) { + AccessTokenProviderFactory accessTokenProviderFactory = + new AccessTokenProviderFactory( new SessionUtilExternalBrowser.DefaultAuthExternalBrowserHandlers(), (int) loginInput.getBrowserResponseTimeout().getSeconds()); + AccessTokenProvider accessTokenProvider = + accessTokenProviderFactory.createAccessTokenProvider( + getAuthenticator(loginInput), loginInput); String oauthAccessToken = accessTokenProvider.getAccessToken(loginInput); loginInput.setAuthenticator(AuthenticatorType.OAUTH.name()); loginInput.setToken(oauthAccessToken); @@ -295,7 +297,7 @@ static SFLoginOutput openSession( AssertUtil.assertTrue( loginInput.getUserName() != null, "missing user name for opening session"); } else { - // OAUTH needs either token or passord + // OAUTH needs either token or password AssertUtil.assertTrue( loginInput.getToken() != null || loginInput.getPassword() != null, "missing token or password for opening session"); diff --git a/src/main/java/net/snowflake/client/core/auth/AuthenticatorType.java b/src/main/java/net/snowflake/client/core/auth/AuthenticatorType.java index a55e91370..e2c2b3054 100644 --- a/src/main/java/net/snowflake/client/core/auth/AuthenticatorType.java +++ b/src/main/java/net/snowflake/client/core/auth/AuthenticatorType.java @@ -46,5 +46,10 @@ public enum AuthenticatorType { /* * Authorization code flow with browser popup */ - OAUTH_AUTHORIZATION_CODE + OAUTH_AUTHORIZATION_CODE, + + /* + * Client credentials flow with clientId and clientSecret as input + */ + OAUTH_CLIENT_CREDENTIALS } diff --git a/src/main/java/net/snowflake/client/core/auth/oauth/OauthAccessTokenProvider.java b/src/main/java/net/snowflake/client/core/auth/oauth/AccessTokenProvider.java similarity index 87% rename from src/main/java/net/snowflake/client/core/auth/oauth/OauthAccessTokenProvider.java rename to src/main/java/net/snowflake/client/core/auth/oauth/AccessTokenProvider.java index 713e6a282..f7d2307b5 100644 --- a/src/main/java/net/snowflake/client/core/auth/oauth/OauthAccessTokenProvider.java +++ b/src/main/java/net/snowflake/client/core/auth/oauth/AccessTokenProvider.java @@ -5,7 +5,7 @@ import net.snowflake.client.core.SnowflakeJdbcInternalApi; @SnowflakeJdbcInternalApi -public interface OauthAccessTokenProvider { +public interface AccessTokenProvider { String getAccessToken(SFLoginInput loginInput) throws SFException; } diff --git a/src/main/java/net/snowflake/client/core/auth/oauth/AccessTokenProviderFactory.java b/src/main/java/net/snowflake/client/core/auth/oauth/AccessTokenProviderFactory.java new file mode 100644 index 000000000..1a2479cc6 --- /dev/null +++ b/src/main/java/net/snowflake/client/core/auth/oauth/AccessTokenProviderFactory.java @@ -0,0 +1,67 @@ +package net.snowflake.client.core.auth.oauth; + +import java.util.Arrays; +import java.util.HashSet; +import java.util.Set; +import net.snowflake.client.core.AssertUtil; +import net.snowflake.client.core.SFException; +import net.snowflake.client.core.SFLoginInput; +import net.snowflake.client.core.SessionUtilExternalBrowser; +import net.snowflake.client.core.SnowflakeJdbcInternalApi; +import net.snowflake.client.core.auth.AuthenticatorType; +import net.snowflake.client.jdbc.ErrorCode; +import net.snowflake.client.log.SFLogger; +import net.snowflake.client.log.SFLoggerFactory; + +@SnowflakeJdbcInternalApi +public class AccessTokenProviderFactory { + + private static final SFLogger logger = + SFLoggerFactory.getLogger(AccessTokenProviderFactory.class); + private static final AuthenticatorType[] ELIGIBLE_AUTH_TYPES = { + AuthenticatorType.OAUTH_AUTHORIZATION_CODE, AuthenticatorType.OAUTH_CLIENT_CREDENTIALS + }; + + private final SessionUtilExternalBrowser.AuthExternalBrowserHandlers browserHandler; + private final int browserAuthorizationTimeoutSeconds; + + public AccessTokenProviderFactory( + SessionUtilExternalBrowser.AuthExternalBrowserHandlers browserHandler, + int browserAuthorizationTimeoutSeconds) { + this.browserHandler = browserHandler; + this.browserAuthorizationTimeoutSeconds = browserAuthorizationTimeoutSeconds; + } + + public AccessTokenProvider createAccessTokenProvider( + AuthenticatorType authenticatorType, SFLoginInput loginInput) throws SFException { + switch (authenticatorType) { + case OAUTH_AUTHORIZATION_CODE: + assertContainsClientCredentials(loginInput); + return new OAuthAuthorizationCodeAccessTokenProvider( + browserHandler, browserAuthorizationTimeoutSeconds); + case OAUTH_CLIENT_CREDENTIALS: + assertContainsClientCredentials(loginInput); + return new OAuthClientCredentialsAccessTokenProvider(); + default: + logger.error("Unsupported authenticator type: " + authenticatorType); + throw new SFException(ErrorCode.INTERNAL_ERROR); + } + } + + public static Set getEligible() { + return new HashSet<>(Arrays.asList(ELIGIBLE_AUTH_TYPES)); + } + + public static boolean isEligible(AuthenticatorType authenticatorType) { + return getEligible().contains(authenticatorType); + } + + private void assertContainsClientCredentials(SFLoginInput loginInput) throws SFException { + AssertUtil.assertTrue( + loginInput.getOauthLoginInput().getClientId() != null, + "passing clientId is required for OAUTH_AUTHORIZATION_CODE_FLOW authentication"); + AssertUtil.assertTrue( + loginInput.getOauthLoginInput().getClientSecret() != null, + "passing clientSecret is required for OAUTH_AUTHORIZATION_CODE_FLOW authentication"); + } +} diff --git a/src/main/java/net/snowflake/client/core/auth/oauth/AuthorizationCodeFlowAccessTokenProvider.java b/src/main/java/net/snowflake/client/core/auth/oauth/OAuthAuthorizationCodeAccessTokenProvider.java similarity index 79% rename from src/main/java/net/snowflake/client/core/auth/oauth/AuthorizationCodeFlowAccessTokenProvider.java rename to src/main/java/net/snowflake/client/core/auth/oauth/OAuthAuthorizationCodeAccessTokenProvider.java index e13cb1b39..e2327a97c 100644 --- a/src/main/java/net/snowflake/client/core/auth/oauth/AuthorizationCodeFlowAccessTokenProvider.java +++ b/src/main/java/net/snowflake/client/core/auth/oauth/OAuthAuthorizationCodeAccessTokenProvider.java @@ -14,7 +14,6 @@ import com.nimbusds.oauth2.sdk.auth.ClientAuthentication; import com.nimbusds.oauth2.sdk.auth.ClientSecretBasic; import com.nimbusds.oauth2.sdk.auth.Secret; -import com.nimbusds.oauth2.sdk.http.HTTPRequest; import com.nimbusds.oauth2.sdk.id.ClientID; import com.nimbusds.oauth2.sdk.id.State; import com.nimbusds.oauth2.sdk.pkce.CodeChallengeMethod; @@ -38,30 +37,23 @@ import net.snowflake.client.log.SFLogger; import net.snowflake.client.log.SFLoggerFactory; import org.apache.http.NameValuePair; -import org.apache.http.client.methods.HttpPost; -import org.apache.http.client.methods.HttpRequestBase; import org.apache.http.client.utils.URLEncodedUtils; -import org.apache.http.entity.StringEntity; @SnowflakeJdbcInternalApi -public class AuthorizationCodeFlowAccessTokenProvider implements OauthAccessTokenProvider { +public class OAuthAuthorizationCodeAccessTokenProvider implements AccessTokenProvider { private static final SFLogger logger = - SFLoggerFactory.getLogger(AuthorizationCodeFlowAccessTokenProvider.class); - - private static final String SNOWFLAKE_AUTHORIZE_ENDPOINT = "/oauth/authorize"; - private static final String SNOWFLAKE_TOKEN_REQUEST_ENDPOINT = "/oauth/token-request"; + SFLoggerFactory.getLogger(OAuthAuthorizationCodeAccessTokenProvider.class); private static final String DEFAULT_REDIRECT_HOST = "http://localhost:8001"; private static final String REDIRECT_URI_ENDPOINT = "/snowflake/oauth-redirect"; private static final String DEFAULT_REDIRECT_URI = DEFAULT_REDIRECT_HOST + REDIRECT_URI_ENDPOINT; - public static final String DEFAULT_SESSION_ROLE_SCOPE_PREFIX = "session:role:"; private final AuthExternalBrowserHandlers browserHandler; private final ObjectMapper objectMapper = new ObjectMapper(); private final int browserAuthorizationTimeoutSeconds; - public AuthorizationCodeFlowAccessTokenProvider( + public OAuthAuthorizationCodeAccessTokenProvider( AuthExternalBrowserHandlers browserHandler, int browserAuthorizationTimeoutSeconds) { this.browserHandler = browserHandler; this.browserAuthorizationTimeoutSeconds = browserAuthorizationTimeoutSeconds; @@ -70,11 +62,14 @@ public AuthorizationCodeFlowAccessTokenProvider( @Override public String getAccessToken(SFLoginInput loginInput) throws SFException { try { + logger.debug("Starting OAuth authorization code authentication flow..."); CodeVerifier pkceVerifier = new CodeVerifier(); AuthorizationCode authorizationCode = requestAuthorizationCode(loginInput, pkceVerifier); return exchangeAuthorizationCodeForAccessToken(loginInput, authorizationCode, pkceVerifier); } catch (Exception e) { - logger.error("Error during OAuth authorization code flow", e); + logger.error( + "Error during OAuth authorization code flow. Verify configuration passed to driver and IdP (URLs, grant types, scope, etc.)", + e); throw new SFException(e, ErrorCode.OAUTH_AUTHORIZATION_CODE_FLOW_ERROR, e.getMessage()); } } @@ -98,10 +93,11 @@ private String exchangeAuthorizationCodeForAccessToken( TokenRequest request = buildTokenRequest(loginInput, authorizationCode, pkceVerifier); URI requestUri = request.getEndpointURI(); logger.debug( - "Requesting access token from: {}", requestUri.getAuthority() + requestUri.getPath()); + "Requesting OAuth access token from: {}", + requestUri.getAuthority() + requestUri.getPath()); String tokenResponse = HttpUtil.executeGeneralRequest( - convertToBaseRequest(request.toHTTPRequest()), + OAuthUtil.convertToBaseRequest(request.toHTTPRequest()), loginInput.getLoginTimeout(), loginInput.getAuthTimeout(), loginInput.getSocketTimeoutInMillis(), @@ -185,14 +181,15 @@ private static AuthorizationRequest buildAuthorizationRequest( ClientID clientID = new ClientID(oauthLoginInput.getClientId()); URI callback = buildRedirectUri(oauthLoginInput); State state = new State(256); - String scope = getScope(loginInput); + String scope = OAuthUtil.getScope(loginInput); return new AuthorizationRequest.Builder(new ResponseType(ResponseType.Value.CODE), clientID) .scope(new Scope(scope)) .state(state) .redirectionURI(callback) .codeChallenge(pkceVerifier, CodeChallengeMethod.S256) .endpointURI( - getAuthorizationUrl(loginInput.getOauthLoginInput(), loginInput.getServerUrl())) + OAuthUtil.getAuthorizationUrl( + loginInput.getOauthLoginInput(), loginInput.getServerUrl())) .build(); } @@ -205,9 +202,9 @@ private static TokenRequest buildTokenRequest( new ClientSecretBasic( new ClientID(loginInput.getOauthLoginInput().getClientId()), new Secret(loginInput.getOauthLoginInput().getClientSecret())); - Scope scope = new Scope(getScope(loginInput)); + Scope scope = new Scope(OAuthUtil.getScope(loginInput)); return new TokenRequest( - getTokenRequestUrl(loginInput.getOauthLoginInput(), loginInput.getServerUrl()), + OAuthUtil.getTokenRequestUrl(loginInput.getOauthLoginInput(), loginInput.getServerUrl()), clientAuthentication, codeGrant, scope); @@ -220,29 +217,4 @@ private static URI buildRedirectUri(SFOauthLoginInput oauthLoginInput) { : DEFAULT_REDIRECT_URI; return URI.create(redirectUri); } - - private static HttpRequestBase convertToBaseRequest(HTTPRequest request) { - HttpPost baseRequest = new HttpPost(request.getURI()); - baseRequest.setEntity(new StringEntity(request.getBody(), StandardCharsets.UTF_8)); - request.getHeaderMap().forEach((key, values) -> baseRequest.addHeader(key, values.get(0))); - return baseRequest; - } - - private static URI getAuthorizationUrl(SFOauthLoginInput oauthLoginInput, String serverUrl) { - return !StringUtils.isNullOrEmpty(oauthLoginInput.getExternalAuthorizationUrl()) - ? URI.create(oauthLoginInput.getExternalAuthorizationUrl()) - : URI.create(serverUrl + SNOWFLAKE_AUTHORIZE_ENDPOINT); - } - - private static URI getTokenRequestUrl(SFOauthLoginInput oauthLoginInput, String serverUrl) { - return !StringUtils.isNullOrEmpty(oauthLoginInput.getExternalTokenRequestUrl()) - ? URI.create(oauthLoginInput.getExternalTokenRequestUrl()) - : URI.create(serverUrl + SNOWFLAKE_TOKEN_REQUEST_ENDPOINT); - } - - private static String getScope(SFLoginInput loginInput) { - return (!StringUtils.isNullOrEmpty(loginInput.getOauthLoginInput().getScope())) - ? loginInput.getOauthLoginInput().getScope() - : DEFAULT_SESSION_ROLE_SCOPE_PREFIX + loginInput.getRole(); - } } diff --git a/src/main/java/net/snowflake/client/core/auth/oauth/OAuthClientCredentialsAccessTokenProvider.java b/src/main/java/net/snowflake/client/core/auth/oauth/OAuthClientCredentialsAccessTokenProvider.java new file mode 100644 index 000000000..67aee6a72 --- /dev/null +++ b/src/main/java/net/snowflake/client/core/auth/oauth/OAuthClientCredentialsAccessTokenProvider.java @@ -0,0 +1,72 @@ +package net.snowflake.client.core.auth.oauth; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.nimbusds.oauth2.sdk.ClientCredentialsGrant; +import com.nimbusds.oauth2.sdk.Scope; +import com.nimbusds.oauth2.sdk.TokenRequest; +import com.nimbusds.oauth2.sdk.auth.ClientAuthentication; +import com.nimbusds.oauth2.sdk.auth.ClientSecretBasic; +import com.nimbusds.oauth2.sdk.auth.Secret; +import com.nimbusds.oauth2.sdk.id.ClientID; +import java.net.URI; +import net.snowflake.client.core.HttpUtil; +import net.snowflake.client.core.SFException; +import net.snowflake.client.core.SFLoginInput; +import net.snowflake.client.core.SnowflakeJdbcInternalApi; +import net.snowflake.client.jdbc.ErrorCode; +import net.snowflake.client.log.SFLogger; +import net.snowflake.client.log.SFLoggerFactory; + +@SnowflakeJdbcInternalApi +public class OAuthClientCredentialsAccessTokenProvider implements AccessTokenProvider { + + private static final SFLogger logger = + SFLoggerFactory.getLogger(OAuthClientCredentialsAccessTokenProvider.class); + + private final ObjectMapper objectMapper = new ObjectMapper(); + + @Override + public String getAccessToken(SFLoginInput loginInput) throws SFException { + try { + logger.debug("Starting OAuth authorization code authentication flow..."); + TokenRequest tokenRequest = buildTokenRequest(loginInput); + TokenResponseDTO tokenResponse = requestForAccessToken(loginInput, tokenRequest); + return tokenResponse.getAccessToken(); + } catch (Exception e) { + logger.error("Error during OAuth client credentials code flow", e); + throw new SFException(e, ErrorCode.OAUTH_CLIENT_CREDENTIALS_FLOW_ERROR, e.getMessage()); + } + } + + private TokenResponseDTO requestForAccessToken(SFLoginInput loginInput, TokenRequest tokenRequest) + throws Exception { + URI requestUri = tokenRequest.getEndpointURI(); + logger.debug( + "Requesting OAuth access token from: {}", requestUri.getAuthority() + requestUri.getPath()); + String tokenResponse = + HttpUtil.executeGeneralRequest( + OAuthUtil.convertToBaseRequest(tokenRequest.toHTTPRequest()), + loginInput.getLoginTimeout(), + loginInput.getAuthTimeout(), + loginInput.getSocketTimeoutInMillis(), + 0, + loginInput.getHttpClientSettingsKey()); + TokenResponseDTO tokenResponseDTO = + objectMapper.readValue(tokenResponse, TokenResponseDTO.class); + logger.debug( + "Received OAuth access token from: {}", requestUri.getAuthority() + requestUri.getPath()); + return tokenResponseDTO; + } + + private static TokenRequest buildTokenRequest(SFLoginInput loginInput) { + URI tokenRequestUrl = + OAuthUtil.getTokenRequestUrl(loginInput.getOauthLoginInput(), loginInput.getServerUrl()); + ClientAuthentication clientAuthentication = + new ClientSecretBasic( + new ClientID(loginInput.getOauthLoginInput().getClientId()), + new Secret(loginInput.getOauthLoginInput().getClientSecret())); + Scope scope = new Scope(OAuthUtil.getScope(loginInput)); + return new TokenRequest( + tokenRequestUrl, clientAuthentication, new ClientCredentialsGrant(), scope); + } +} diff --git a/src/main/java/net/snowflake/client/core/auth/oauth/OAuthUtil.java b/src/main/java/net/snowflake/client/core/auth/oauth/OAuthUtil.java new file mode 100644 index 000000000..732e0fd11 --- /dev/null +++ b/src/main/java/net/snowflake/client/core/auth/oauth/OAuthUtil.java @@ -0,0 +1,46 @@ +package net.snowflake.client.core.auth.oauth; + +import com.amazonaws.util.StringUtils; +import com.nimbusds.oauth2.sdk.http.HTTPRequest; +import java.net.URI; +import java.nio.charset.StandardCharsets; +import net.snowflake.client.core.SFLoginInput; +import net.snowflake.client.core.SFOauthLoginInput; +import net.snowflake.client.core.SnowflakeJdbcInternalApi; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.client.methods.HttpRequestBase; +import org.apache.http.entity.StringEntity; + +@SnowflakeJdbcInternalApi +public class OAuthUtil { + + private static final String SNOWFLAKE_AUTHORIZE_ENDPOINT = "/oauth/authorize"; + private static final String SNOWFLAKE_TOKEN_REQUEST_ENDPOINT = "/oauth/token-request"; + + private static final String DEFAULT_SESSION_ROLE_SCOPE_PREFIX = "session:role:"; + + public static HttpRequestBase convertToBaseRequest(HTTPRequest request) { + HttpPost baseRequest = new HttpPost(request.getURI()); + baseRequest.setEntity(new StringEntity(request.getBody(), StandardCharsets.UTF_8)); + request.getHeaderMap().forEach((key, values) -> baseRequest.addHeader(key, values.get(0))); + return baseRequest; + } + + public static URI getAuthorizationUrl(SFOauthLoginInput oauthLoginInput, String serverUrl) { + return !StringUtils.isNullOrEmpty(oauthLoginInput.getExternalAuthorizationUrl()) + ? URI.create(oauthLoginInput.getExternalAuthorizationUrl()) + : URI.create(serverUrl + SNOWFLAKE_AUTHORIZE_ENDPOINT); + } + + public static URI getTokenRequestUrl(SFOauthLoginInput oauthLoginInput, String serverUrl) { + return !StringUtils.isNullOrEmpty(oauthLoginInput.getExternalTokenRequestUrl()) + ? URI.create(oauthLoginInput.getExternalTokenRequestUrl()) + : URI.create(serverUrl + SNOWFLAKE_TOKEN_REQUEST_ENDPOINT); + } + + public static String getScope(SFLoginInput loginInput) { + return (!StringUtils.isNullOrEmpty(loginInput.getOauthLoginInput().getScope())) + ? loginInput.getOauthLoginInput().getScope() + : DEFAULT_SESSION_ROLE_SCOPE_PREFIX + loginInput.getRole(); + } +} diff --git a/src/main/java/net/snowflake/client/jdbc/ErrorCode.java b/src/main/java/net/snowflake/client/jdbc/ErrorCode.java index cd9ee019b..765cbc657 100644 --- a/src/main/java/net/snowflake/client/jdbc/ErrorCode.java +++ b/src/main/java/net/snowflake/client/jdbc/ErrorCode.java @@ -84,7 +84,8 @@ public enum ErrorCode { GCP_SERVICE_ERROR(200061, SqlState.SYSTEM_ERROR), AUTHENTICATOR_REQUEST_TIMEOUT(200062, SqlState.CONNECTION_EXCEPTION), INVALID_STRUCT_DATA(200063, SqlState.DATA_EXCEPTION), - OAUTH_AUTHORIZATION_CODE_FLOW_ERROR(200064, SqlState.CONNECTION_EXCEPTION); + OAUTH_AUTHORIZATION_CODE_FLOW_ERROR(200064, SqlState.CONNECTION_EXCEPTION), + OAUTH_CLIENT_CREDENTIALS_FLOW_ERROR(200065, SqlState.CONNECTION_EXCEPTION); public static final String errorMessageResource = "net.snowflake.client.jdbc.jdbc_error_messages"; diff --git a/src/main/resources/net/snowflake/client/jdbc/jdbc_error_messages.properties b/src/main/resources/net/snowflake/client/jdbc/jdbc_error_messages.properties index 43c96b5cf..cd88e8190 100644 --- a/src/main/resources/net/snowflake/client/jdbc/jdbc_error_messages.properties +++ b/src/main/resources/net/snowflake/client/jdbc/jdbc_error_messages.properties @@ -82,5 +82,5 @@ Error message={3}, Extended error info={4} 200061=GCS operation failed: Operation={0}, Error code={1}, Message={2}, Reason={3} 200062=Authentication timed out. 200063=Invalid data - Cannot be parsed and converted to structured type. -200064=Error during OAuth authentication: {0} - +200064=Error during OAuth Authorization Code authentication: {0} +200065=Error during OAuth Client Credentials authentication: {0} diff --git a/src/test/java/net/snowflake/client/AbstractDriverIT.java b/src/test/java/net/snowflake/client/AbstractDriverIT.java index 3104ce7e9..f028b8f8e 100644 --- a/src/test/java/net/snowflake/client/AbstractDriverIT.java +++ b/src/test/java/net/snowflake/client/AbstractDriverIT.java @@ -24,6 +24,7 @@ import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; +import net.snowflake.client.core.auth.AuthenticatorType; /** Base test class with common constants, data structures and methods */ public class AbstractDriverIT { @@ -324,6 +325,8 @@ public static Connection getConnection( properties.put("internal", Boolean.TRUE.toString()); // TODO: do we need this? properties.put("insecureMode", false); // use OCSP for all tests. + properties.put("authenticator", AuthenticatorType.OAUTH_CLIENT_CREDENTIALS.name()); + if (injectSocketTimeout > 0) { properties.put("injectSocketTimeout", String.valueOf(injectSocketTimeout)); } diff --git a/src/test/java/net/snowflake/client/jdbc/OauthAuthorizationCodeFlowLatestIT.java b/src/test/java/net/snowflake/client/jdbc/OauthAuthorizationCodeFlowLatestIT.java index 9e6153812..ccce86add 100644 --- a/src/test/java/net/snowflake/client/jdbc/OauthAuthorizationCodeFlowLatestIT.java +++ b/src/test/java/net/snowflake/client/jdbc/OauthAuthorizationCodeFlowLatestIT.java @@ -11,8 +11,8 @@ import net.snowflake.client.core.SFException; import net.snowflake.client.core.SFLoginInput; import net.snowflake.client.core.SFOauthLoginInput; -import net.snowflake.client.core.auth.oauth.AuthorizationCodeFlowAccessTokenProvider; -import net.snowflake.client.core.auth.oauth.OauthAccessTokenProvider; +import net.snowflake.client.core.auth.oauth.AccessTokenProvider; +import net.snowflake.client.core.auth.oauth.OAuthAuthorizationCodeAccessTokenProvider; import org.apache.http.HttpResponse; import org.apache.http.client.methods.HttpGet; import org.apache.http.client.methods.HttpPost; @@ -51,8 +51,8 @@ public void successfulFlowScenario() throws SFException { SFLoginInput loginInput = createLoginInputStub("http://localhost:8009/snowflake/oauth-redirect", null, null); - OauthAccessTokenProvider provider = - new AuthorizationCodeFlowAccessTokenProvider(wiremockProxyRequestBrowserHandler, 30); + AccessTokenProvider provider = + new OAuthAuthorizationCodeAccessTokenProvider(wiremockProxyRequestBrowserHandler, 30); String accessToken = provider.getAccessToken(loginInput); Assertions.assertFalse(StringUtils.isNullOrEmpty(accessToken)); @@ -68,8 +68,8 @@ public void customUrlsScenario() throws SFException { String.format("http://%s:%d/authorization", WIREMOCK_HOST, wiremockHttpPort), String.format("http://%s:%d/tokenrequest", WIREMOCK_HOST, wiremockHttpPort)); - OauthAccessTokenProvider provider = - new AuthorizationCodeFlowAccessTokenProvider(wiremockProxyRequestBrowserHandler, 30); + AccessTokenProvider provider = + new OAuthAuthorizationCodeAccessTokenProvider(wiremockProxyRequestBrowserHandler, 30); String accessToken = provider.getAccessToken(loginInput); Assertions.assertFalse(StringUtils.isNullOrEmpty(accessToken)); @@ -82,8 +82,8 @@ public void browserTimeoutFlowScenario() { SFLoginInput loginInput = createLoginInputStub("http://localhost:8004/snowflake/oauth-redirect", null, null); - OauthAccessTokenProvider provider = - new AuthorizationCodeFlowAccessTokenProvider(wiremockProxyRequestBrowserHandler, 1); + AccessTokenProvider provider = + new OAuthAuthorizationCodeAccessTokenProvider(wiremockProxyRequestBrowserHandler, 1); SFException e = Assertions.assertThrows(SFException.class, () -> provider.getAccessToken(loginInput)); Assertions.assertTrue( @@ -97,8 +97,8 @@ public void invalidScopeFlowScenario() { importMappingFromResources(INVALID_SCOPE_SCENARIO_MAPPING); SFLoginInput loginInput = createLoginInputStub("http://localhost:8002/snowflake/oauth-redirect", null, null); - OauthAccessTokenProvider provider = - new AuthorizationCodeFlowAccessTokenProvider(wiremockProxyRequestBrowserHandler, 30); + AccessTokenProvider provider = + new OAuthAuthorizationCodeAccessTokenProvider(wiremockProxyRequestBrowserHandler, 30); SFException e = Assertions.assertThrows(SFException.class, () -> provider.getAccessToken(loginInput)); Assertions.assertTrue( @@ -113,8 +113,8 @@ public void tokenRequestErrorFlowScenario() { SFLoginInput loginInput = createLoginInputStub("http://localhost:8003/snowflake/oauth-redirect", null, null); - OauthAccessTokenProvider provider = - new AuthorizationCodeFlowAccessTokenProvider(wiremockProxyRequestBrowserHandler, 30); + AccessTokenProvider provider = + new OAuthAuthorizationCodeAccessTokenProvider(wiremockProxyRequestBrowserHandler, 30); SFException e = Assertions.assertThrows(SFException.class, () -> provider.getAccessToken(loginInput)); Assertions.assertTrue( From a45c797957269563ff919eada456453ae8c0ce4c Mon Sep 17 00:00:00 2001 From: Dawid Heyman Date: Mon, 9 Dec 2024 22:39:05 +0100 Subject: [PATCH 2/7] Add client credentials working flow --- .../net/snowflake/client/core/HttpUtil.java | 6 - .../snowflake/client/core/SessionUtil.java | 1 + .../oauth/AccessTokenProviderFactory.java | 12 +- ...hAuthorizationCodeAccessTokenProvider.java | 4 +- ...hClientCredentialsAccessTokenProvider.java | 4 +- .../client/core/auth/oauth/OAuthUtil.java | 8 +- .../snowflake/client/AbstractDriverIT.java | 4 +- .../oauth/AccessTokenProviderFactoryTest.java | 84 ++++++++++++++ .../client/core/auth/oauth/OAuthUtilTest.java | 66 +++++++++++ ...> OAuthAuthorizationCodeFlowLatestIT.java} | 8 +- .../OAuthClientCredentialsFlowLatestIT.java | 109 ++++++++++++++++++ .../successful_scenario_mapping.json | 30 +++++ .../token_request_error_scenario_mapping.json | 29 +++++ 13 files changed, 341 insertions(+), 24 deletions(-) create mode 100644 src/test/java/net/snowflake/client/core/auth/oauth/AccessTokenProviderFactoryTest.java create mode 100644 src/test/java/net/snowflake/client/core/auth/oauth/OAuthUtilTest.java rename src/test/java/net/snowflake/client/jdbc/{OauthAuthorizationCodeFlowLatestIT.java => OAuthAuthorizationCodeFlowLatestIT.java} (95%) create mode 100644 src/test/java/net/snowflake/client/jdbc/OAuthClientCredentialsFlowLatestIT.java create mode 100644 src/test/resources/oauth/client_credentials/successful_scenario_mapping.json create mode 100644 src/test/resources/oauth/client_credentials/token_request_error_scenario_mapping.json diff --git a/src/main/java/net/snowflake/client/core/HttpUtil.java b/src/main/java/net/snowflake/client/core/HttpUtil.java index 7a6b92ada..7f2507566 100644 --- a/src/main/java/net/snowflake/client/core/HttpUtil.java +++ b/src/main/java/net/snowflake/client/core/HttpUtil.java @@ -836,12 +836,6 @@ private static String executeRequestInternal( stopwatch.stop(); } - writer = new StringWriter(); - try (InputStream ins = response.getEntity().getContent()) { - IOUtils.copy(ins, writer, "UTF-8"); - } - theString = writer.toString(); - if (response == null || response.getStatusLine().getStatusCode() != 200) { logger.error("Error executing request: {}", requestInfoScrubbed); diff --git a/src/main/java/net/snowflake/client/core/SessionUtil.java b/src/main/java/net/snowflake/client/core/SessionUtil.java index dafbdf7fd..ef7682ac9 100644 --- a/src/main/java/net/snowflake/client/core/SessionUtil.java +++ b/src/main/java/net/snowflake/client/core/SessionUtil.java @@ -289,6 +289,7 @@ static SFLoginOutput openSession( String oauthAccessToken = accessTokenProvider.getAccessToken(loginInput); loginInput.setAuthenticator(AuthenticatorType.OAUTH.name()); loginInput.setToken(oauthAccessToken); + loginInput.setUserName("0oalpyiuy8rmozhjZ5d7"); } final AuthenticatorType authenticator = getAuthenticator(loginInput); diff --git a/src/main/java/net/snowflake/client/core/auth/oauth/AccessTokenProviderFactory.java b/src/main/java/net/snowflake/client/core/auth/oauth/AccessTokenProviderFactory.java index 1a2479cc6..105fb8a4d 100644 --- a/src/main/java/net/snowflake/client/core/auth/oauth/AccessTokenProviderFactory.java +++ b/src/main/java/net/snowflake/client/core/auth/oauth/AccessTokenProviderFactory.java @@ -36,11 +36,13 @@ public AccessTokenProvider createAccessTokenProvider( AuthenticatorType authenticatorType, SFLoginInput loginInput) throws SFException { switch (authenticatorType) { case OAUTH_AUTHORIZATION_CODE: - assertContainsClientCredentials(loginInput); + assertContainsClientCredentials(loginInput, authenticatorType); return new OAuthAuthorizationCodeAccessTokenProvider( browserHandler, browserAuthorizationTimeoutSeconds); case OAUTH_CLIENT_CREDENTIALS: - assertContainsClientCredentials(loginInput); + assertContainsClientCredentials(loginInput, authenticatorType); + AssertUtil.assertTrue(loginInput.getOauthLoginInput().getExternalTokenRequestUrl() != null, + "passing externalTokenRequestUrl is required for OAUTH_CLIENT_CREDENTIALS authentication"); return new OAuthClientCredentialsAccessTokenProvider(); default: logger.error("Unsupported authenticator type: " + authenticatorType); @@ -56,12 +58,12 @@ public static boolean isEligible(AuthenticatorType authenticatorType) { return getEligible().contains(authenticatorType); } - private void assertContainsClientCredentials(SFLoginInput loginInput) throws SFException { + private void assertContainsClientCredentials(SFLoginInput loginInput, AuthenticatorType authenticatorType) throws SFException { AssertUtil.assertTrue( loginInput.getOauthLoginInput().getClientId() != null, - "passing clientId is required for OAUTH_AUTHORIZATION_CODE_FLOW authentication"); + String.format("passing clientId is required for %s authentication", authenticatorType.name())); AssertUtil.assertTrue( loginInput.getOauthLoginInput().getClientSecret() != null, - "passing clientSecret is required for OAUTH_AUTHORIZATION_CODE_FLOW authentication"); + String.format("passing clientSecret is required for %s authentication", authenticatorType.name())); } } diff --git a/src/main/java/net/snowflake/client/core/auth/oauth/OAuthAuthorizationCodeAccessTokenProvider.java b/src/main/java/net/snowflake/client/core/auth/oauth/OAuthAuthorizationCodeAccessTokenProvider.java index e2327a97c..80280dcc9 100644 --- a/src/main/java/net/snowflake/client/core/auth/oauth/OAuthAuthorizationCodeAccessTokenProvider.java +++ b/src/main/java/net/snowflake/client/core/auth/oauth/OAuthAuthorizationCodeAccessTokenProvider.java @@ -181,7 +181,7 @@ private static AuthorizationRequest buildAuthorizationRequest( ClientID clientID = new ClientID(oauthLoginInput.getClientId()); URI callback = buildRedirectUri(oauthLoginInput); State state = new State(256); - String scope = OAuthUtil.getScope(loginInput); + String scope = OAuthUtil.getScope(loginInput.getOauthLoginInput(), loginInput.getRole()); return new AuthorizationRequest.Builder(new ResponseType(ResponseType.Value.CODE), clientID) .scope(new Scope(scope)) .state(state) @@ -202,7 +202,7 @@ private static TokenRequest buildTokenRequest( new ClientSecretBasic( new ClientID(loginInput.getOauthLoginInput().getClientId()), new Secret(loginInput.getOauthLoginInput().getClientSecret())); - Scope scope = new Scope(OAuthUtil.getScope(loginInput)); + Scope scope = new Scope(OAuthUtil.getScope(loginInput.getOauthLoginInput(), loginInput.getRole())); return new TokenRequest( OAuthUtil.getTokenRequestUrl(loginInput.getOauthLoginInput(), loginInput.getServerUrl()), clientAuthentication, diff --git a/src/main/java/net/snowflake/client/core/auth/oauth/OAuthClientCredentialsAccessTokenProvider.java b/src/main/java/net/snowflake/client/core/auth/oauth/OAuthClientCredentialsAccessTokenProvider.java index 67aee6a72..909fa23ad 100644 --- a/src/main/java/net/snowflake/client/core/auth/oauth/OAuthClientCredentialsAccessTokenProvider.java +++ b/src/main/java/net/snowflake/client/core/auth/oauth/OAuthClientCredentialsAccessTokenProvider.java @@ -33,7 +33,7 @@ public String getAccessToken(SFLoginInput loginInput) throws SFException { TokenResponseDTO tokenResponse = requestForAccessToken(loginInput, tokenRequest); return tokenResponse.getAccessToken(); } catch (Exception e) { - logger.error("Error during OAuth client credentials code flow", e); + logger.error("Error during OAuth client credentials code flow. Verify configuration passed to driver and IdP (URLs, grant types, scope, etc.)", e); throw new SFException(e, ErrorCode.OAUTH_CLIENT_CREDENTIALS_FLOW_ERROR, e.getMessage()); } } @@ -65,7 +65,7 @@ private static TokenRequest buildTokenRequest(SFLoginInput loginInput) { new ClientSecretBasic( new ClientID(loginInput.getOauthLoginInput().getClientId()), new Secret(loginInput.getOauthLoginInput().getClientSecret())); - Scope scope = new Scope(OAuthUtil.getScope(loginInput)); + Scope scope = new Scope(OAuthUtil.getScope(loginInput.getOauthLoginInput(), loginInput.getRole())); return new TokenRequest( tokenRequestUrl, clientAuthentication, new ClientCredentialsGrant(), scope); } diff --git a/src/main/java/net/snowflake/client/core/auth/oauth/OAuthUtil.java b/src/main/java/net/snowflake/client/core/auth/oauth/OAuthUtil.java index 732e0fd11..6a9120e74 100644 --- a/src/main/java/net/snowflake/client/core/auth/oauth/OAuthUtil.java +++ b/src/main/java/net/snowflake/client/core/auth/oauth/OAuthUtil.java @@ -38,9 +38,9 @@ public static URI getTokenRequestUrl(SFOauthLoginInput oauthLoginInput, String s : URI.create(serverUrl + SNOWFLAKE_TOKEN_REQUEST_ENDPOINT); } - public static String getScope(SFLoginInput loginInput) { - return (!StringUtils.isNullOrEmpty(loginInput.getOauthLoginInput().getScope())) - ? loginInput.getOauthLoginInput().getScope() - : DEFAULT_SESSION_ROLE_SCOPE_PREFIX + loginInput.getRole(); + public static String getScope(SFOauthLoginInput oauthLoginInput, String role) { + return (!StringUtils.isNullOrEmpty(oauthLoginInput.getScope())) + ? oauthLoginInput.getScope() + : DEFAULT_SESSION_ROLE_SCOPE_PREFIX + role; } } diff --git a/src/test/java/net/snowflake/client/AbstractDriverIT.java b/src/test/java/net/snowflake/client/AbstractDriverIT.java index f028b8f8e..93ad3085a 100644 --- a/src/test/java/net/snowflake/client/AbstractDriverIT.java +++ b/src/test/java/net/snowflake/client/AbstractDriverIT.java @@ -6,6 +6,9 @@ import static org.hamcrest.MatcherAssert.assertThat; import com.google.common.base.Strings; +import net.snowflake.client.core.auth.AuthenticatorType; +import net.snowflake.common.core.ClientAuthnDTO; + import java.net.URISyntaxException; import java.net.URL; import java.nio.file.Paths; @@ -24,7 +27,6 @@ import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; -import net.snowflake.client.core.auth.AuthenticatorType; /** Base test class with common constants, data structures and methods */ public class AbstractDriverIT { diff --git a/src/test/java/net/snowflake/client/core/auth/oauth/AccessTokenProviderFactoryTest.java b/src/test/java/net/snowflake/client/core/auth/oauth/AccessTokenProviderFactoryTest.java new file mode 100644 index 000000000..c6a7f2790 --- /dev/null +++ b/src/test/java/net/snowflake/client/core/auth/oauth/AccessTokenProviderFactoryTest.java @@ -0,0 +1,84 @@ +package net.snowflake.client.core.auth.oauth; + +import net.snowflake.client.core.SFException; +import net.snowflake.client.core.SFLoginInput; +import net.snowflake.client.core.SFOauthLoginInput; +import net.snowflake.client.core.auth.AuthenticatorType; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; + +public class AccessTokenProviderFactoryTest { + + private final AccessTokenProviderFactory providerFactory = new AccessTokenProviderFactory(null, 30); + + @Test + public void shouldProperlyReturnIfAuthenticatorIsEligible() { + Arrays.stream(AuthenticatorType.values()).forEach(authenticatorType -> { + if (authenticatorType == AuthenticatorType.OAUTH_CLIENT_CREDENTIALS || authenticatorType.equals(AuthenticatorType.OAUTH_AUTHORIZATION_CODE)) { + Assertions.assertTrue(AccessTokenProviderFactory.isEligible(authenticatorType)); + } else { + Assertions.assertFalse(AccessTokenProviderFactory.isEligible(authenticatorType)); + } + }); + } + + @Test + public void shouldProperlyCreateClientCredentialsAccessTokenProvider() throws SFException { + SFLoginInput loginInput = createLoginInputStub("123", "123", "some/url"); + AccessTokenProvider provider = providerFactory.createAccessTokenProvider(AuthenticatorType.OAUTH_CLIENT_CREDENTIALS, loginInput); + Assertions.assertNotNull(provider); + Assertions.assertInstanceOf(OAuthClientCredentialsAccessTokenProvider.class, provider); + } + + @Test + public void shouldFailToCreateClientCredentialsAccessTokenProviderWithoutClientId() { + SFLoginInput loginInput = createLoginInputStub(null, "123", "some/url"); + SFException e = Assertions.assertThrows(SFException.class, () -> providerFactory.createAccessTokenProvider(AuthenticatorType.OAUTH_CLIENT_CREDENTIALS, loginInput)); + Assertions.assertTrue(e.getMessage().contains("passing clientId is required for OAUTH_CLIENT_CREDENTIALS authentication.")); + } + + @Test + public void shouldFailToCreateClientCredentialsAccessTokenProviderWithoutClientSecret() { + SFLoginInput loginInput = createLoginInputStub("123", null, "some/url"); + SFException e = Assertions.assertThrows(SFException.class, () -> providerFactory.createAccessTokenProvider(AuthenticatorType.OAUTH_CLIENT_CREDENTIALS, loginInput)); + Assertions.assertTrue(e.getMessage().contains("passing clientSecret is required for OAUTH_CLIENT_CREDENTIALS authentication.")); + } + + @Test + public void shouldFailToCreateClientCredentialsAccessTokenProviderWithoutClientAuthzUrl() { + SFLoginInput loginInput = createLoginInputStub("123", "123", null); + SFException e = Assertions.assertThrows(SFException.class, () -> providerFactory.createAccessTokenProvider(AuthenticatorType.OAUTH_CLIENT_CREDENTIALS, loginInput)); + Assertions.assertTrue(e.getMessage().contains("passing externalTokenRequestUrl is required for OAUTH_CLIENT_CREDENTIALS authentication.")); + } + + @Test + public void shouldProperlyCreateAuthorizationCodeAccessTokenProvider() throws SFException { + SFLoginInput loginInput = createLoginInputStub("123", "123", ""); + AccessTokenProvider provider = providerFactory.createAccessTokenProvider(AuthenticatorType.OAUTH_AUTHORIZATION_CODE, loginInput); + Assertions.assertNotNull(provider); + Assertions.assertInstanceOf(OAuthAuthorizationCodeAccessTokenProvider.class, provider); + } + + @Test + public void shouldFailToCreateAuthzCodeAccessTokenProviderWithoutClientId() { + SFLoginInput loginInput = createLoginInputStub(null, "123", "some/url"); + SFException e = Assertions.assertThrows(SFException.class, () -> providerFactory.createAccessTokenProvider(AuthenticatorType.OAUTH_AUTHORIZATION_CODE, loginInput)); + Assertions.assertTrue(e.getMessage().contains("passing clientId is required for OAUTH_AUTHORIZATION_CODE authentication.")); + } + + @Test + public void shouldFailToCreateAuthzCodeAccessTokenProviderWithoutClientSecret() { + SFLoginInput loginInput = createLoginInputStub("123", null, null); + SFException e = Assertions.assertThrows(SFException.class, () -> providerFactory.createAccessTokenProvider(AuthenticatorType.OAUTH_AUTHORIZATION_CODE, loginInput)); + Assertions.assertTrue(e.getMessage().contains("passing clientSecret is required for OAUTH_AUTHORIZATION_CODE authentication.")); + } + + private SFLoginInput createLoginInputStub(String clientId, String clientSecret, String externalTokenUrl) { + SFLoginInput loginInput = new SFLoginInput(); + loginInput.setOauthLoginInput( + new SFOauthLoginInput(clientId, clientSecret, null, null, externalTokenUrl, null)); + return loginInput; + } +} \ No newline at end of file diff --git a/src/test/java/net/snowflake/client/core/auth/oauth/OAuthUtilTest.java b/src/test/java/net/snowflake/client/core/auth/oauth/OAuthUtilTest.java new file mode 100644 index 000000000..d9e773801 --- /dev/null +++ b/src/test/java/net/snowflake/client/core/auth/oauth/OAuthUtilTest.java @@ -0,0 +1,66 @@ +package net.snowflake.client.core.auth.oauth; + +import net.snowflake.client.core.SFOauthLoginInput; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import java.net.URI; + +public class OAuthUtilTest { + + private static final String BASE_SERVER_URL_FROM_LOGIN_INPUT = "http://some.snowflake.server.com"; + public static final String ROLE_FROM_LOGIN_INPUT = "ANALYST"; + + @Test + public void shouldCreateDefaultAuthorizationUrl() { + SFOauthLoginInput loginInput = createLoginInputStub(null, null, null); + URI authorizationUrl = OAuthUtil.getAuthorizationUrl(loginInput, BASE_SERVER_URL_FROM_LOGIN_INPUT); + Assertions.assertNotNull(authorizationUrl); + Assertions.assertEquals("http://some.snowflake.server.com/oauth/authorize", authorizationUrl.toString()); + } + + @Test + public void shouldCreateUserSuppliedAuthorizationUrl() { + SFOauthLoginInput loginInput = createLoginInputStub("http://some.external.authorization.url.com/authz", null, null); + URI tokenRequestUrl = OAuthUtil.getAuthorizationUrl(loginInput, BASE_SERVER_URL_FROM_LOGIN_INPUT); + Assertions.assertNotNull(tokenRequestUrl); + Assertions.assertEquals("http://some.external.authorization.url.com/authz", tokenRequestUrl.toString()); + } + + @Test + public void shouldCreateDefaultTokenRequestUrl() { + SFOauthLoginInput loginInput = createLoginInputStub(null, null, null); + URI tokenRequestUrl = OAuthUtil.getTokenRequestUrl(loginInput, BASE_SERVER_URL_FROM_LOGIN_INPUT); + Assertions.assertNotNull(tokenRequestUrl); + Assertions.assertEquals("http://some.snowflake.server.com/oauth/token-request", tokenRequestUrl.toString()); + } + + @Test + public void shouldCreateUserSuppliedTokenRequestUrl() { + SFOauthLoginInput loginInput = createLoginInputStub(null, "http://some.external.authorization.url.com/token-request", null); + URI tokenRequestUrl = OAuthUtil.getTokenRequestUrl(loginInput, BASE_SERVER_URL_FROM_LOGIN_INPUT); + Assertions.assertNotNull(tokenRequestUrl); + Assertions.assertEquals("http://some.external.authorization.url.com/token-request", tokenRequestUrl.toString()); + } + + @Test + public void shouldCreateDefaultScope() { + SFOauthLoginInput loginInput = createLoginInputStub(null, null, null); + String scope = OAuthUtil.getScope(loginInput, ROLE_FROM_LOGIN_INPUT); + Assertions.assertNotNull(scope); + Assertions.assertEquals("session:role:ANALYST", scope); + } + + @Test + public void shouldCreateUserSuppliedScope() { + SFOauthLoginInput loginInput = createLoginInputStub(null, null, "some:custom:SCOPE"); + String scope = OAuthUtil.getScope(loginInput, ROLE_FROM_LOGIN_INPUT); + Assertions.assertNotNull(scope); + Assertions.assertEquals("some:custom:SCOPE", scope); + } + + private SFOauthLoginInput createLoginInputStub(String externalAuthorizationUrl, String externalTokenRequestUrl, String scope) { + return new SFOauthLoginInput(null, null, null, externalAuthorizationUrl, externalTokenRequestUrl, scope); + } + +} \ No newline at end of file diff --git a/src/test/java/net/snowflake/client/jdbc/OauthAuthorizationCodeFlowLatestIT.java b/src/test/java/net/snowflake/client/jdbc/OAuthAuthorizationCodeFlowLatestIT.java similarity index 95% rename from src/test/java/net/snowflake/client/jdbc/OauthAuthorizationCodeFlowLatestIT.java rename to src/test/java/net/snowflake/client/jdbc/OAuthAuthorizationCodeFlowLatestIT.java index ccce86add..0fc567ffc 100644 --- a/src/test/java/net/snowflake/client/jdbc/OauthAuthorizationCodeFlowLatestIT.java +++ b/src/test/java/net/snowflake/client/jdbc/OAuthAuthorizationCodeFlowLatestIT.java @@ -25,7 +25,7 @@ import org.slf4j.LoggerFactory; @Tag(TestTags.CORE) -public class OauthAuthorizationCodeFlowLatestIT extends BaseWiremockTest { +public class OAuthAuthorizationCodeFlowLatestIT extends BaseWiremockTest { private static final String SCENARIOS_BASE_DIR = "/oauth/authorization_code"; private static final String SUCCESSFUL_FLOW_SCENARIO_MAPPINGS = @@ -40,7 +40,7 @@ public class OauthAuthorizationCodeFlowLatestIT extends BaseWiremockTest { SCENARIOS_BASE_DIR + "/custom_urls_scenario_mapping.json"; private static final Logger logger = - LoggerFactory.getLogger(OauthAuthorizationCodeFlowLatestIT.class); + LoggerFactory.getLogger(OAuthAuthorizationCodeFlowLatestIT.class); private final AuthExternalBrowserHandlers wiremockProxyRequestBrowserHandler = new WiremockProxyRequestBrowserHandler(); @@ -120,7 +120,7 @@ public void tokenRequestErrorFlowScenario() { Assertions.assertTrue( e.getMessage() .contains( - "Error during OAuth authentication: JDBC driver encountered communication error. Message: HTTP status=400.")); + "JDBC driver encountered communication error. Message: HTTP status=400")); } private SFLoginInput createLoginInputStub( @@ -129,7 +129,7 @@ private SFLoginInput createLoginInputStub( loginInputStub.setServerUrl(String.format("http://%s:%d/", WIREMOCK_HOST, wiremockHttpPort)); loginInputStub.setOauthLoginInput( new SFOauthLoginInput( - "123", "123", redirectUri, externalAuthorizationUrl, externalTokenUrl, "ANALYST")); + "123", "123", redirectUri, externalAuthorizationUrl, externalTokenUrl, "session:role:ANALYST")); loginInputStub.setSocketTimeout(Duration.ofMinutes(5)); loginInputStub.setHttpClientSettingsKey(new HttpClientSettingsKey(OCSPMode.FAIL_OPEN)); diff --git a/src/test/java/net/snowflake/client/jdbc/OAuthClientCredentialsFlowLatestIT.java b/src/test/java/net/snowflake/client/jdbc/OAuthClientCredentialsFlowLatestIT.java new file mode 100644 index 000000000..753029ca4 --- /dev/null +++ b/src/test/java/net/snowflake/client/jdbc/OAuthClientCredentialsFlowLatestIT.java @@ -0,0 +1,109 @@ +package net.snowflake.client.jdbc; + +import com.amazonaws.util.StringUtils; +import net.snowflake.client.category.TestTags; +import net.snowflake.client.core.HttpClientSettingsKey; +import net.snowflake.client.core.OCSPMode; +import net.snowflake.client.core.SFException; +import net.snowflake.client.core.SFLoginInput; +import net.snowflake.client.core.SFOauthLoginInput; +import net.snowflake.client.core.auth.oauth.AccessTokenProvider; +import net.snowflake.client.core.auth.oauth.OAuthAuthorizationCodeAccessTokenProvider; +import net.snowflake.client.core.auth.oauth.OAuthClientCredentialsAccessTokenProvider; +import org.apache.http.HttpResponse; +import org.apache.http.client.methods.HttpGet; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.impl.client.CloseableHttpClient; +import org.apache.http.impl.client.HttpClients; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.net.URI; +import java.time.Duration; + +import static net.snowflake.client.core.SessionUtilExternalBrowser.AuthExternalBrowserHandlers; + +@Tag(TestTags.CORE) +public class OAuthClientCredentialsFlowLatestIT extends BaseWiremockTest { + + private static final String SCENARIOS_BASE_DIR = "/oauth/client_credentials"; + private static final String SUCCESSFUL_FLOW_SCENARIO_MAPPINGS = + SCENARIOS_BASE_DIR + "/successful_scenario_mapping.json"; + private static final String TOKEN_REQUEST_ERROR_SCENARIO_MAPPING = + SCENARIOS_BASE_DIR + "/token_request_error_scenario_mapping.json"; + + private static final Logger logger = + LoggerFactory.getLogger(OAuthClientCredentialsFlowLatestIT.class); + + @Test + public void successfulFlowScenario() throws SFException { + importMappingFromResources(SUCCESSFUL_FLOW_SCENARIO_MAPPINGS); + SFLoginInput loginInput = + createLoginInputStub("http://localhost:8009/snowflake/oauth-redirect"); + + AccessTokenProvider provider = + new OAuthClientCredentialsAccessTokenProvider(); + String accessToken = provider.getAccessToken(loginInput); + + Assertions.assertFalse(StringUtils.isNullOrEmpty(accessToken)); + Assertions.assertEquals("access-token-123", accessToken); + } + + @Test + public void tokenRequestErrorFlowScenario() { + importMappingFromResources(TOKEN_REQUEST_ERROR_SCENARIO_MAPPING); + SFLoginInput loginInput = + createLoginInputStub("http://localhost:8003/snowflake/oauth-redirect"); + + AccessTokenProvider provider = + new OAuthClientCredentialsAccessTokenProvider(); + SFException e = + Assertions.assertThrows(SFException.class, () -> provider.getAccessToken(loginInput)); + Assertions.assertTrue( + e.getMessage() + .contains( + "JDBC driver encountered communication error. Message: HTTP status=400")); + } + + private SFLoginInput createLoginInputStub( + String redirectUri) { + SFLoginInput loginInputStub = new SFLoginInput(); + loginInputStub.setServerUrl(String.format("http://%s:%d/", WIREMOCK_HOST, wiremockHttpPort)); + loginInputStub.setOauthLoginInput( + new SFOauthLoginInput( + "123", "123", redirectUri, null, String.format("http://%s:%d/oauth/token-request", WIREMOCK_HOST, wiremockHttpPort), "session:role:ANALYST")); + loginInputStub.setSocketTimeout(Duration.ofMinutes(5)); + loginInputStub.setHttpClientSettingsKey(new HttpClientSettingsKey(OCSPMode.FAIL_OPEN)); + + return loginInputStub; + } + + static class WiremockProxyRequestBrowserHandler implements AuthExternalBrowserHandlers { + @Override + public HttpPost build(URI uri) { + // do nothing + return null; + } + + @Override + public void openBrowser(String ssoUrl) { + try (CloseableHttpClient client = HttpClients.createDefault()) { + logger.debug("executing browser request to redirect uri: {}", ssoUrl); + HttpResponse response = client.execute(new HttpGet(ssoUrl)); + if (response.getStatusLine().getStatusCode() != 200) { + throw new RuntimeException("Invalid response from " + ssoUrl); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + @Override + public void output(String msg) { + // do nothing + } + } +} diff --git a/src/test/resources/oauth/client_credentials/successful_scenario_mapping.json b/src/test/resources/oauth/client_credentials/successful_scenario_mapping.json new file mode 100644 index 000000000..cf8dd32da --- /dev/null +++ b/src/test/resources/oauth/client_credentials/successful_scenario_mapping.json @@ -0,0 +1,30 @@ +{ + "mappings": [ + { + "scenarioName": "Successful OAuth client credentials flow", + "requiredScenarioState": "Started", + "newScenarioState": "Acquired access token", + "request": { + "urlPathPattern": "/oauth/token-request.*", + "method": "POST", + "headers": { + "Authorization": { + "contains": "Basic" + }, + "Content-Type": { + "contains": "application/x-www-form-urlencoded; charset=UTF-8" + } + }, + "bodyPatterns": [ + { + "contains": "grant_type=client_credentials&scope=session%3Arole%3AANALYST" + } + ] + }, + "response": { + "status": 200, + "body": "{ \"access_token\" : \"access-token-123\", \"refresh_token\" : \"123\", \"token_type\" : \"Bearer\", \"username\" : \"user\", \"scope\" : \"refresh_token session:role:ANALYST\", \"expires_in\" : 600, \"refresh_token_expires_in\" : 86399, \"idpInitiated\" : false }" + } + } + ] +} diff --git a/src/test/resources/oauth/client_credentials/token_request_error_scenario_mapping.json b/src/test/resources/oauth/client_credentials/token_request_error_scenario_mapping.json new file mode 100644 index 000000000..72eb97481 --- /dev/null +++ b/src/test/resources/oauth/client_credentials/token_request_error_scenario_mapping.json @@ -0,0 +1,29 @@ +{ + "mappings": [ + { + "scenarioName": "OAuth client credentials flow with token request error", + "requiredScenarioState": "Started", + "newScenarioState": "Acquired access token", + "request": { + "urlPathPattern": "/oauth/token-request.*", + "method": "POST", + "headers": { + "Authorization": { + "contains": "Basic" + }, + "Content-Type": { + "contains": "application/x-www-form-urlencoded; charset=UTF-8" + } + }, + "bodyPatterns": [ + { + "contains": "grant_type=client_credentials&scope=session%3Arole%3AANALYST" + } + ] + }, + "response": { + "status": 400 + } + } + ] +} \ No newline at end of file From fd42d991836dd7949995907f26e55fea22e20d79 Mon Sep 17 00:00:00 2001 From: Dawid Heyman Date: Mon, 9 Dec 2024 22:39:43 +0100 Subject: [PATCH 3/7] Reformat --- .../oauth/AccessTokenProviderFactory.java | 14 +- ...hAuthorizationCodeAccessTokenProvider.java | 3 +- ...hClientCredentialsAccessTokenProvider.java | 7 +- .../client/core/auth/oauth/OAuthUtil.java | 1 - .../snowflake/client/AbstractDriverIT.java | 4 +- .../oauth/AccessTokenProviderFactoryTest.java | 170 +++++++++++------- .../client/core/auth/oauth/OAuthUtilTest.java | 113 ++++++------ .../OAuthAuthorizationCodeFlowLatestIT.java | 10 +- .../OAuthClientCredentialsFlowLatestIT.java | 29 ++- 9 files changed, 208 insertions(+), 143 deletions(-) diff --git a/src/main/java/net/snowflake/client/core/auth/oauth/AccessTokenProviderFactory.java b/src/main/java/net/snowflake/client/core/auth/oauth/AccessTokenProviderFactory.java index 105fb8a4d..3d7f56587 100644 --- a/src/main/java/net/snowflake/client/core/auth/oauth/AccessTokenProviderFactory.java +++ b/src/main/java/net/snowflake/client/core/auth/oauth/AccessTokenProviderFactory.java @@ -41,8 +41,9 @@ public AccessTokenProvider createAccessTokenProvider( browserHandler, browserAuthorizationTimeoutSeconds); case OAUTH_CLIENT_CREDENTIALS: assertContainsClientCredentials(loginInput, authenticatorType); - AssertUtil.assertTrue(loginInput.getOauthLoginInput().getExternalTokenRequestUrl() != null, - "passing externalTokenRequestUrl is required for OAUTH_CLIENT_CREDENTIALS authentication"); + AssertUtil.assertTrue( + loginInput.getOauthLoginInput().getExternalTokenRequestUrl() != null, + "passing externalTokenRequestUrl is required for OAUTH_CLIENT_CREDENTIALS authentication"); return new OAuthClientCredentialsAccessTokenProvider(); default: logger.error("Unsupported authenticator type: " + authenticatorType); @@ -58,12 +59,15 @@ public static boolean isEligible(AuthenticatorType authenticatorType) { return getEligible().contains(authenticatorType); } - private void assertContainsClientCredentials(SFLoginInput loginInput, AuthenticatorType authenticatorType) throws SFException { + private void assertContainsClientCredentials( + SFLoginInput loginInput, AuthenticatorType authenticatorType) throws SFException { AssertUtil.assertTrue( loginInput.getOauthLoginInput().getClientId() != null, - String.format("passing clientId is required for %s authentication", authenticatorType.name())); + String.format( + "passing clientId is required for %s authentication", authenticatorType.name())); AssertUtil.assertTrue( loginInput.getOauthLoginInput().getClientSecret() != null, - String.format("passing clientSecret is required for %s authentication", authenticatorType.name())); + String.format( + "passing clientSecret is required for %s authentication", authenticatorType.name())); } } diff --git a/src/main/java/net/snowflake/client/core/auth/oauth/OAuthAuthorizationCodeAccessTokenProvider.java b/src/main/java/net/snowflake/client/core/auth/oauth/OAuthAuthorizationCodeAccessTokenProvider.java index 80280dcc9..b45c032e4 100644 --- a/src/main/java/net/snowflake/client/core/auth/oauth/OAuthAuthorizationCodeAccessTokenProvider.java +++ b/src/main/java/net/snowflake/client/core/auth/oauth/OAuthAuthorizationCodeAccessTokenProvider.java @@ -202,7 +202,8 @@ private static TokenRequest buildTokenRequest( new ClientSecretBasic( new ClientID(loginInput.getOauthLoginInput().getClientId()), new Secret(loginInput.getOauthLoginInput().getClientSecret())); - Scope scope = new Scope(OAuthUtil.getScope(loginInput.getOauthLoginInput(), loginInput.getRole())); + Scope scope = + new Scope(OAuthUtil.getScope(loginInput.getOauthLoginInput(), loginInput.getRole())); return new TokenRequest( OAuthUtil.getTokenRequestUrl(loginInput.getOauthLoginInput(), loginInput.getServerUrl()), clientAuthentication, diff --git a/src/main/java/net/snowflake/client/core/auth/oauth/OAuthClientCredentialsAccessTokenProvider.java b/src/main/java/net/snowflake/client/core/auth/oauth/OAuthClientCredentialsAccessTokenProvider.java index 909fa23ad..948f7852c 100644 --- a/src/main/java/net/snowflake/client/core/auth/oauth/OAuthClientCredentialsAccessTokenProvider.java +++ b/src/main/java/net/snowflake/client/core/auth/oauth/OAuthClientCredentialsAccessTokenProvider.java @@ -33,7 +33,9 @@ public String getAccessToken(SFLoginInput loginInput) throws SFException { TokenResponseDTO tokenResponse = requestForAccessToken(loginInput, tokenRequest); return tokenResponse.getAccessToken(); } catch (Exception e) { - logger.error("Error during OAuth client credentials code flow. Verify configuration passed to driver and IdP (URLs, grant types, scope, etc.)", e); + logger.error( + "Error during OAuth client credentials code flow. Verify configuration passed to driver and IdP (URLs, grant types, scope, etc.)", + e); throw new SFException(e, ErrorCode.OAUTH_CLIENT_CREDENTIALS_FLOW_ERROR, e.getMessage()); } } @@ -65,7 +67,8 @@ private static TokenRequest buildTokenRequest(SFLoginInput loginInput) { new ClientSecretBasic( new ClientID(loginInput.getOauthLoginInput().getClientId()), new Secret(loginInput.getOauthLoginInput().getClientSecret())); - Scope scope = new Scope(OAuthUtil.getScope(loginInput.getOauthLoginInput(), loginInput.getRole())); + Scope scope = + new Scope(OAuthUtil.getScope(loginInput.getOauthLoginInput(), loginInput.getRole())); return new TokenRequest( tokenRequestUrl, clientAuthentication, new ClientCredentialsGrant(), scope); } diff --git a/src/main/java/net/snowflake/client/core/auth/oauth/OAuthUtil.java b/src/main/java/net/snowflake/client/core/auth/oauth/OAuthUtil.java index 6a9120e74..ddb307abd 100644 --- a/src/main/java/net/snowflake/client/core/auth/oauth/OAuthUtil.java +++ b/src/main/java/net/snowflake/client/core/auth/oauth/OAuthUtil.java @@ -4,7 +4,6 @@ import com.nimbusds.oauth2.sdk.http.HTTPRequest; import java.net.URI; import java.nio.charset.StandardCharsets; -import net.snowflake.client.core.SFLoginInput; import net.snowflake.client.core.SFOauthLoginInput; import net.snowflake.client.core.SnowflakeJdbcInternalApi; import org.apache.http.client.methods.HttpPost; diff --git a/src/test/java/net/snowflake/client/AbstractDriverIT.java b/src/test/java/net/snowflake/client/AbstractDriverIT.java index 93ad3085a..f028b8f8e 100644 --- a/src/test/java/net/snowflake/client/AbstractDriverIT.java +++ b/src/test/java/net/snowflake/client/AbstractDriverIT.java @@ -6,9 +6,6 @@ import static org.hamcrest.MatcherAssert.assertThat; import com.google.common.base.Strings; -import net.snowflake.client.core.auth.AuthenticatorType; -import net.snowflake.common.core.ClientAuthnDTO; - import java.net.URISyntaxException; import java.net.URL; import java.nio.file.Paths; @@ -27,6 +24,7 @@ import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; +import net.snowflake.client.core.auth.AuthenticatorType; /** Base test class with common constants, data structures and methods */ public class AbstractDriverIT { diff --git a/src/test/java/net/snowflake/client/core/auth/oauth/AccessTokenProviderFactoryTest.java b/src/test/java/net/snowflake/client/core/auth/oauth/AccessTokenProviderFactoryTest.java index c6a7f2790..54747ee8b 100644 --- a/src/test/java/net/snowflake/client/core/auth/oauth/AccessTokenProviderFactoryTest.java +++ b/src/test/java/net/snowflake/client/core/auth/oauth/AccessTokenProviderFactoryTest.java @@ -1,5 +1,6 @@ package net.snowflake.client.core.auth.oauth; +import java.util.Arrays; import net.snowflake.client.core.SFException; import net.snowflake.client.core.SFLoginInput; import net.snowflake.client.core.SFOauthLoginInput; @@ -7,78 +8,123 @@ import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; -import java.util.Arrays; - public class AccessTokenProviderFactoryTest { - private final AccessTokenProviderFactory providerFactory = new AccessTokenProviderFactory(null, 30); + private final AccessTokenProviderFactory providerFactory = + new AccessTokenProviderFactory(null, 30); - @Test - public void shouldProperlyReturnIfAuthenticatorIsEligible() { - Arrays.stream(AuthenticatorType.values()).forEach(authenticatorType -> { - if (authenticatorType == AuthenticatorType.OAUTH_CLIENT_CREDENTIALS || authenticatorType.equals(AuthenticatorType.OAUTH_AUTHORIZATION_CODE)) { + @Test + public void shouldProperlyReturnIfAuthenticatorIsEligible() { + Arrays.stream(AuthenticatorType.values()) + .forEach( + authenticatorType -> { + if (authenticatorType == AuthenticatorType.OAUTH_CLIENT_CREDENTIALS + || authenticatorType.equals(AuthenticatorType.OAUTH_AUTHORIZATION_CODE)) { Assertions.assertTrue(AccessTokenProviderFactory.isEligible(authenticatorType)); - } else { + } else { Assertions.assertFalse(AccessTokenProviderFactory.isEligible(authenticatorType)); - } - }); - } + } + }); + } - @Test - public void shouldProperlyCreateClientCredentialsAccessTokenProvider() throws SFException { - SFLoginInput loginInput = createLoginInputStub("123", "123", "some/url"); - AccessTokenProvider provider = providerFactory.createAccessTokenProvider(AuthenticatorType.OAUTH_CLIENT_CREDENTIALS, loginInput); - Assertions.assertNotNull(provider); - Assertions.assertInstanceOf(OAuthClientCredentialsAccessTokenProvider.class, provider); - } + @Test + public void shouldProperlyCreateClientCredentialsAccessTokenProvider() throws SFException { + SFLoginInput loginInput = createLoginInputStub("123", "123", "some/url"); + AccessTokenProvider provider = + providerFactory.createAccessTokenProvider( + AuthenticatorType.OAUTH_CLIENT_CREDENTIALS, loginInput); + Assertions.assertNotNull(provider); + Assertions.assertInstanceOf(OAuthClientCredentialsAccessTokenProvider.class, provider); + } - @Test - public void shouldFailToCreateClientCredentialsAccessTokenProviderWithoutClientId() { - SFLoginInput loginInput = createLoginInputStub(null, "123", "some/url"); - SFException e = Assertions.assertThrows(SFException.class, () -> providerFactory.createAccessTokenProvider(AuthenticatorType.OAUTH_CLIENT_CREDENTIALS, loginInput)); - Assertions.assertTrue(e.getMessage().contains("passing clientId is required for OAUTH_CLIENT_CREDENTIALS authentication.")); - } + @Test + public void shouldFailToCreateClientCredentialsAccessTokenProviderWithoutClientId() { + SFLoginInput loginInput = createLoginInputStub(null, "123", "some/url"); + SFException e = + Assertions.assertThrows( + SFException.class, + () -> + providerFactory.createAccessTokenProvider( + AuthenticatorType.OAUTH_CLIENT_CREDENTIALS, loginInput)); + Assertions.assertTrue( + e.getMessage() + .contains("passing clientId is required for OAUTH_CLIENT_CREDENTIALS authentication.")); + } - @Test - public void shouldFailToCreateClientCredentialsAccessTokenProviderWithoutClientSecret() { - SFLoginInput loginInput = createLoginInputStub("123", null, "some/url"); - SFException e = Assertions.assertThrows(SFException.class, () -> providerFactory.createAccessTokenProvider(AuthenticatorType.OAUTH_CLIENT_CREDENTIALS, loginInput)); - Assertions.assertTrue(e.getMessage().contains("passing clientSecret is required for OAUTH_CLIENT_CREDENTIALS authentication.")); - } + @Test + public void shouldFailToCreateClientCredentialsAccessTokenProviderWithoutClientSecret() { + SFLoginInput loginInput = createLoginInputStub("123", null, "some/url"); + SFException e = + Assertions.assertThrows( + SFException.class, + () -> + providerFactory.createAccessTokenProvider( + AuthenticatorType.OAUTH_CLIENT_CREDENTIALS, loginInput)); + Assertions.assertTrue( + e.getMessage() + .contains( + "passing clientSecret is required for OAUTH_CLIENT_CREDENTIALS authentication.")); + } - @Test - public void shouldFailToCreateClientCredentialsAccessTokenProviderWithoutClientAuthzUrl() { - SFLoginInput loginInput = createLoginInputStub("123", "123", null); - SFException e = Assertions.assertThrows(SFException.class, () -> providerFactory.createAccessTokenProvider(AuthenticatorType.OAUTH_CLIENT_CREDENTIALS, loginInput)); - Assertions.assertTrue(e.getMessage().contains("passing externalTokenRequestUrl is required for OAUTH_CLIENT_CREDENTIALS authentication.")); - } + @Test + public void shouldFailToCreateClientCredentialsAccessTokenProviderWithoutClientAuthzUrl() { + SFLoginInput loginInput = createLoginInputStub("123", "123", null); + SFException e = + Assertions.assertThrows( + SFException.class, + () -> + providerFactory.createAccessTokenProvider( + AuthenticatorType.OAUTH_CLIENT_CREDENTIALS, loginInput)); + Assertions.assertTrue( + e.getMessage() + .contains( + "passing externalTokenRequestUrl is required for OAUTH_CLIENT_CREDENTIALS authentication.")); + } - @Test - public void shouldProperlyCreateAuthorizationCodeAccessTokenProvider() throws SFException { - SFLoginInput loginInput = createLoginInputStub("123", "123", ""); - AccessTokenProvider provider = providerFactory.createAccessTokenProvider(AuthenticatorType.OAUTH_AUTHORIZATION_CODE, loginInput); - Assertions.assertNotNull(provider); - Assertions.assertInstanceOf(OAuthAuthorizationCodeAccessTokenProvider.class, provider); - } + @Test + public void shouldProperlyCreateAuthorizationCodeAccessTokenProvider() throws SFException { + SFLoginInput loginInput = createLoginInputStub("123", "123", ""); + AccessTokenProvider provider = + providerFactory.createAccessTokenProvider( + AuthenticatorType.OAUTH_AUTHORIZATION_CODE, loginInput); + Assertions.assertNotNull(provider); + Assertions.assertInstanceOf(OAuthAuthorizationCodeAccessTokenProvider.class, provider); + } - @Test - public void shouldFailToCreateAuthzCodeAccessTokenProviderWithoutClientId() { - SFLoginInput loginInput = createLoginInputStub(null, "123", "some/url"); - SFException e = Assertions.assertThrows(SFException.class, () -> providerFactory.createAccessTokenProvider(AuthenticatorType.OAUTH_AUTHORIZATION_CODE, loginInput)); - Assertions.assertTrue(e.getMessage().contains("passing clientId is required for OAUTH_AUTHORIZATION_CODE authentication.")); - } + @Test + public void shouldFailToCreateAuthzCodeAccessTokenProviderWithoutClientId() { + SFLoginInput loginInput = createLoginInputStub(null, "123", "some/url"); + SFException e = + Assertions.assertThrows( + SFException.class, + () -> + providerFactory.createAccessTokenProvider( + AuthenticatorType.OAUTH_AUTHORIZATION_CODE, loginInput)); + Assertions.assertTrue( + e.getMessage() + .contains("passing clientId is required for OAUTH_AUTHORIZATION_CODE authentication.")); + } - @Test - public void shouldFailToCreateAuthzCodeAccessTokenProviderWithoutClientSecret() { - SFLoginInput loginInput = createLoginInputStub("123", null, null); - SFException e = Assertions.assertThrows(SFException.class, () -> providerFactory.createAccessTokenProvider(AuthenticatorType.OAUTH_AUTHORIZATION_CODE, loginInput)); - Assertions.assertTrue(e.getMessage().contains("passing clientSecret is required for OAUTH_AUTHORIZATION_CODE authentication.")); - } + @Test + public void shouldFailToCreateAuthzCodeAccessTokenProviderWithoutClientSecret() { + SFLoginInput loginInput = createLoginInputStub("123", null, null); + SFException e = + Assertions.assertThrows( + SFException.class, + () -> + providerFactory.createAccessTokenProvider( + AuthenticatorType.OAUTH_AUTHORIZATION_CODE, loginInput)); + Assertions.assertTrue( + e.getMessage() + .contains( + "passing clientSecret is required for OAUTH_AUTHORIZATION_CODE authentication.")); + } - private SFLoginInput createLoginInputStub(String clientId, String clientSecret, String externalTokenUrl) { - SFLoginInput loginInput = new SFLoginInput(); - loginInput.setOauthLoginInput( - new SFOauthLoginInput(clientId, clientSecret, null, null, externalTokenUrl, null)); - return loginInput; - } -} \ No newline at end of file + private SFLoginInput createLoginInputStub( + String clientId, String clientSecret, String externalTokenUrl) { + SFLoginInput loginInput = new SFLoginInput(); + loginInput.setOauthLoginInput( + new SFOauthLoginInput(clientId, clientSecret, null, null, externalTokenUrl, null)); + return loginInput; + } +} diff --git a/src/test/java/net/snowflake/client/core/auth/oauth/OAuthUtilTest.java b/src/test/java/net/snowflake/client/core/auth/oauth/OAuthUtilTest.java index d9e773801..bb6baded0 100644 --- a/src/test/java/net/snowflake/client/core/auth/oauth/OAuthUtilTest.java +++ b/src/test/java/net/snowflake/client/core/auth/oauth/OAuthUtilTest.java @@ -1,66 +1,77 @@ package net.snowflake.client.core.auth.oauth; +import java.net.URI; import net.snowflake.client.core.SFOauthLoginInput; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; -import java.net.URI; - public class OAuthUtilTest { - private static final String BASE_SERVER_URL_FROM_LOGIN_INPUT = "http://some.snowflake.server.com"; - public static final String ROLE_FROM_LOGIN_INPUT = "ANALYST"; - - @Test - public void shouldCreateDefaultAuthorizationUrl() { - SFOauthLoginInput loginInput = createLoginInputStub(null, null, null); - URI authorizationUrl = OAuthUtil.getAuthorizationUrl(loginInput, BASE_SERVER_URL_FROM_LOGIN_INPUT); - Assertions.assertNotNull(authorizationUrl); - Assertions.assertEquals("http://some.snowflake.server.com/oauth/authorize", authorizationUrl.toString()); - } + private static final String BASE_SERVER_URL_FROM_LOGIN_INPUT = "http://some.snowflake.server.com"; + public static final String ROLE_FROM_LOGIN_INPUT = "ANALYST"; - @Test - public void shouldCreateUserSuppliedAuthorizationUrl() { - SFOauthLoginInput loginInput = createLoginInputStub("http://some.external.authorization.url.com/authz", null, null); - URI tokenRequestUrl = OAuthUtil.getAuthorizationUrl(loginInput, BASE_SERVER_URL_FROM_LOGIN_INPUT); - Assertions.assertNotNull(tokenRequestUrl); - Assertions.assertEquals("http://some.external.authorization.url.com/authz", tokenRequestUrl.toString()); - } + @Test + public void shouldCreateDefaultAuthorizationUrl() { + SFOauthLoginInput loginInput = createLoginInputStub(null, null, null); + URI authorizationUrl = + OAuthUtil.getAuthorizationUrl(loginInput, BASE_SERVER_URL_FROM_LOGIN_INPUT); + Assertions.assertNotNull(authorizationUrl); + Assertions.assertEquals( + "http://some.snowflake.server.com/oauth/authorize", authorizationUrl.toString()); + } - @Test - public void shouldCreateDefaultTokenRequestUrl() { - SFOauthLoginInput loginInput = createLoginInputStub(null, null, null); - URI tokenRequestUrl = OAuthUtil.getTokenRequestUrl(loginInput, BASE_SERVER_URL_FROM_LOGIN_INPUT); - Assertions.assertNotNull(tokenRequestUrl); - Assertions.assertEquals("http://some.snowflake.server.com/oauth/token-request", tokenRequestUrl.toString()); - } + @Test + public void shouldCreateUserSuppliedAuthorizationUrl() { + SFOauthLoginInput loginInput = + createLoginInputStub("http://some.external.authorization.url.com/authz", null, null); + URI tokenRequestUrl = + OAuthUtil.getAuthorizationUrl(loginInput, BASE_SERVER_URL_FROM_LOGIN_INPUT); + Assertions.assertNotNull(tokenRequestUrl); + Assertions.assertEquals( + "http://some.external.authorization.url.com/authz", tokenRequestUrl.toString()); + } - @Test - public void shouldCreateUserSuppliedTokenRequestUrl() { - SFOauthLoginInput loginInput = createLoginInputStub(null, "http://some.external.authorization.url.com/token-request", null); - URI tokenRequestUrl = OAuthUtil.getTokenRequestUrl(loginInput, BASE_SERVER_URL_FROM_LOGIN_INPUT); - Assertions.assertNotNull(tokenRequestUrl); - Assertions.assertEquals("http://some.external.authorization.url.com/token-request", tokenRequestUrl.toString()); - } + @Test + public void shouldCreateDefaultTokenRequestUrl() { + SFOauthLoginInput loginInput = createLoginInputStub(null, null, null); + URI tokenRequestUrl = + OAuthUtil.getTokenRequestUrl(loginInput, BASE_SERVER_URL_FROM_LOGIN_INPUT); + Assertions.assertNotNull(tokenRequestUrl); + Assertions.assertEquals( + "http://some.snowflake.server.com/oauth/token-request", tokenRequestUrl.toString()); + } - @Test - public void shouldCreateDefaultScope() { - SFOauthLoginInput loginInput = createLoginInputStub(null, null, null); - String scope = OAuthUtil.getScope(loginInput, ROLE_FROM_LOGIN_INPUT); - Assertions.assertNotNull(scope); - Assertions.assertEquals("session:role:ANALYST", scope); - } + @Test + public void shouldCreateUserSuppliedTokenRequestUrl() { + SFOauthLoginInput loginInput = + createLoginInputStub( + null, "http://some.external.authorization.url.com/token-request", null); + URI tokenRequestUrl = + OAuthUtil.getTokenRequestUrl(loginInput, BASE_SERVER_URL_FROM_LOGIN_INPUT); + Assertions.assertNotNull(tokenRequestUrl); + Assertions.assertEquals( + "http://some.external.authorization.url.com/token-request", tokenRequestUrl.toString()); + } - @Test - public void shouldCreateUserSuppliedScope() { - SFOauthLoginInput loginInput = createLoginInputStub(null, null, "some:custom:SCOPE"); - String scope = OAuthUtil.getScope(loginInput, ROLE_FROM_LOGIN_INPUT); - Assertions.assertNotNull(scope); - Assertions.assertEquals("some:custom:SCOPE", scope); - } + @Test + public void shouldCreateDefaultScope() { + SFOauthLoginInput loginInput = createLoginInputStub(null, null, null); + String scope = OAuthUtil.getScope(loginInput, ROLE_FROM_LOGIN_INPUT); + Assertions.assertNotNull(scope); + Assertions.assertEquals("session:role:ANALYST", scope); + } - private SFOauthLoginInput createLoginInputStub(String externalAuthorizationUrl, String externalTokenRequestUrl, String scope) { - return new SFOauthLoginInput(null, null, null, externalAuthorizationUrl, externalTokenRequestUrl, scope); - } + @Test + public void shouldCreateUserSuppliedScope() { + SFOauthLoginInput loginInput = createLoginInputStub(null, null, "some:custom:SCOPE"); + String scope = OAuthUtil.getScope(loginInput, ROLE_FROM_LOGIN_INPUT); + Assertions.assertNotNull(scope); + Assertions.assertEquals("some:custom:SCOPE", scope); + } -} \ No newline at end of file + private SFOauthLoginInput createLoginInputStub( + String externalAuthorizationUrl, String externalTokenRequestUrl, String scope) { + return new SFOauthLoginInput( + null, null, null, externalAuthorizationUrl, externalTokenRequestUrl, scope); + } +} diff --git a/src/test/java/net/snowflake/client/jdbc/OAuthAuthorizationCodeFlowLatestIT.java b/src/test/java/net/snowflake/client/jdbc/OAuthAuthorizationCodeFlowLatestIT.java index 0fc567ffc..50c501302 100644 --- a/src/test/java/net/snowflake/client/jdbc/OAuthAuthorizationCodeFlowLatestIT.java +++ b/src/test/java/net/snowflake/client/jdbc/OAuthAuthorizationCodeFlowLatestIT.java @@ -119,8 +119,7 @@ public void tokenRequestErrorFlowScenario() { Assertions.assertThrows(SFException.class, () -> provider.getAccessToken(loginInput)); Assertions.assertTrue( e.getMessage() - .contains( - "JDBC driver encountered communication error. Message: HTTP status=400")); + .contains("JDBC driver encountered communication error. Message: HTTP status=400")); } private SFLoginInput createLoginInputStub( @@ -129,7 +128,12 @@ private SFLoginInput createLoginInputStub( loginInputStub.setServerUrl(String.format("http://%s:%d/", WIREMOCK_HOST, wiremockHttpPort)); loginInputStub.setOauthLoginInput( new SFOauthLoginInput( - "123", "123", redirectUri, externalAuthorizationUrl, externalTokenUrl, "session:role:ANALYST")); + "123", + "123", + redirectUri, + externalAuthorizationUrl, + externalTokenUrl, + "session:role:ANALYST")); loginInputStub.setSocketTimeout(Duration.ofMinutes(5)); loginInputStub.setHttpClientSettingsKey(new HttpClientSettingsKey(OCSPMode.FAIL_OPEN)); diff --git a/src/test/java/net/snowflake/client/jdbc/OAuthClientCredentialsFlowLatestIT.java b/src/test/java/net/snowflake/client/jdbc/OAuthClientCredentialsFlowLatestIT.java index 753029ca4..ac0db9868 100644 --- a/src/test/java/net/snowflake/client/jdbc/OAuthClientCredentialsFlowLatestIT.java +++ b/src/test/java/net/snowflake/client/jdbc/OAuthClientCredentialsFlowLatestIT.java @@ -1,6 +1,10 @@ package net.snowflake.client.jdbc; +import static net.snowflake.client.core.SessionUtilExternalBrowser.AuthExternalBrowserHandlers; + import com.amazonaws.util.StringUtils; +import java.net.URI; +import java.time.Duration; import net.snowflake.client.category.TestTags; import net.snowflake.client.core.HttpClientSettingsKey; import net.snowflake.client.core.OCSPMode; @@ -8,7 +12,6 @@ import net.snowflake.client.core.SFLoginInput; import net.snowflake.client.core.SFOauthLoginInput; import net.snowflake.client.core.auth.oauth.AccessTokenProvider; -import net.snowflake.client.core.auth.oauth.OAuthAuthorizationCodeAccessTokenProvider; import net.snowflake.client.core.auth.oauth.OAuthClientCredentialsAccessTokenProvider; import org.apache.http.HttpResponse; import org.apache.http.client.methods.HttpGet; @@ -21,11 +24,6 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.net.URI; -import java.time.Duration; - -import static net.snowflake.client.core.SessionUtilExternalBrowser.AuthExternalBrowserHandlers; - @Tag(TestTags.CORE) public class OAuthClientCredentialsFlowLatestIT extends BaseWiremockTest { @@ -44,8 +42,7 @@ public void successfulFlowScenario() throws SFException { SFLoginInput loginInput = createLoginInputStub("http://localhost:8009/snowflake/oauth-redirect"); - AccessTokenProvider provider = - new OAuthClientCredentialsAccessTokenProvider(); + AccessTokenProvider provider = new OAuthClientCredentialsAccessTokenProvider(); String accessToken = provider.getAccessToken(loginInput); Assertions.assertFalse(StringUtils.isNullOrEmpty(accessToken)); @@ -58,23 +55,25 @@ public void tokenRequestErrorFlowScenario() { SFLoginInput loginInput = createLoginInputStub("http://localhost:8003/snowflake/oauth-redirect"); - AccessTokenProvider provider = - new OAuthClientCredentialsAccessTokenProvider(); + AccessTokenProvider provider = new OAuthClientCredentialsAccessTokenProvider(); SFException e = Assertions.assertThrows(SFException.class, () -> provider.getAccessToken(loginInput)); Assertions.assertTrue( e.getMessage() - .contains( - "JDBC driver encountered communication error. Message: HTTP status=400")); + .contains("JDBC driver encountered communication error. Message: HTTP status=400")); } - private SFLoginInput createLoginInputStub( - String redirectUri) { + private SFLoginInput createLoginInputStub(String redirectUri) { SFLoginInput loginInputStub = new SFLoginInput(); loginInputStub.setServerUrl(String.format("http://%s:%d/", WIREMOCK_HOST, wiremockHttpPort)); loginInputStub.setOauthLoginInput( new SFOauthLoginInput( - "123", "123", redirectUri, null, String.format("http://%s:%d/oauth/token-request", WIREMOCK_HOST, wiremockHttpPort), "session:role:ANALYST")); + "123", + "123", + redirectUri, + null, + String.format("http://%s:%d/oauth/token-request", WIREMOCK_HOST, wiremockHttpPort), + "session:role:ANALYST")); loginInputStub.setSocketTimeout(Duration.ofMinutes(5)); loginInputStub.setHttpClientSettingsKey(new HttpClientSettingsKey(OCSPMode.FAIL_OPEN)); From 44b90a1464a5f1518102be90b00b202a68313333 Mon Sep 17 00:00:00 2001 From: Dawid Heyman Date: Tue, 10 Dec 2024 11:57:30 +0100 Subject: [PATCH 4/7] Refactor --- linkage-checker-exclusion-rules.xml | 5 +++++ .../core/auth/oauth/AccessTokenProvider.java | 4 ++++ .../oauth/AccessTokenProviderFactory.java | 4 ++++ ...hAuthorizationCodeAccessTokenProvider.java | 19 ++++++++++++------- ...hClientCredentialsAccessTokenProvider.java | 4 ++++ .../core/auth/oauth/TokenResponseDTO.java | 6 ++++-- .../oauth/AccessTokenProviderFactoryTest.java | 4 ++++ .../client/core/auth/oauth/OAuthUtilTest.java | 4 ++++ .../OAuthAuthorizationCodeFlowLatestIT.java | 4 ++++ .../OAuthClientCredentialsFlowLatestIT.java | 4 ++++ 10 files changed, 49 insertions(+), 9 deletions(-) diff --git a/linkage-checker-exclusion-rules.xml b/linkage-checker-exclusion-rules.xml index eb7207b32..cc7640d8f 100644 --- a/linkage-checker-exclusion-rules.xml +++ b/linkage-checker-exclusion-rules.xml @@ -14,6 +14,11 @@ Optional + + + + Optional + diff --git a/src/main/java/net/snowflake/client/core/auth/oauth/AccessTokenProvider.java b/src/main/java/net/snowflake/client/core/auth/oauth/AccessTokenProvider.java index f7d2307b5..19e3ffe4a 100644 --- a/src/main/java/net/snowflake/client/core/auth/oauth/AccessTokenProvider.java +++ b/src/main/java/net/snowflake/client/core/auth/oauth/AccessTokenProvider.java @@ -1,3 +1,7 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + package net.snowflake.client.core.auth.oauth; import net.snowflake.client.core.SFException; diff --git a/src/main/java/net/snowflake/client/core/auth/oauth/AccessTokenProviderFactory.java b/src/main/java/net/snowflake/client/core/auth/oauth/AccessTokenProviderFactory.java index 3d7f56587..a2153e326 100644 --- a/src/main/java/net/snowflake/client/core/auth/oauth/AccessTokenProviderFactory.java +++ b/src/main/java/net/snowflake/client/core/auth/oauth/AccessTokenProviderFactory.java @@ -1,3 +1,7 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + package net.snowflake.client.core.auth.oauth; import java.util.Arrays; diff --git a/src/main/java/net/snowflake/client/core/auth/oauth/OAuthAuthorizationCodeAccessTokenProvider.java b/src/main/java/net/snowflake/client/core/auth/oauth/OAuthAuthorizationCodeAccessTokenProvider.java index b45c032e4..ef9df7af6 100644 --- a/src/main/java/net/snowflake/client/core/auth/oauth/OAuthAuthorizationCodeAccessTokenProvider.java +++ b/src/main/java/net/snowflake/client/core/auth/oauth/OAuthAuthorizationCodeAccessTokenProvider.java @@ -1,3 +1,7 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + package net.snowflake.client.core.auth.oauth; import static net.snowflake.client.core.SessionUtilExternalBrowser.AuthExternalBrowserHandlers; @@ -106,7 +110,9 @@ private String exchangeAuthorizationCodeForAccessToken( TokenResponseDTO tokenResponseDTO = objectMapper.readValue(tokenResponse, TokenResponseDTO.class); logger.debug( - "Received OAuth access token from: {}", requestUri.getAuthority() + requestUri.getPath()); + "Received OAuth access token from: {}{}", + requestUri.getAuthority(), + requestUri.getPath()); return tokenResponseDTO.getAccessToken(); } catch (Exception e) { logger.error("Error during making OAuth access token request", e); @@ -124,13 +130,12 @@ private AuthorizationCode letUserAuthorize( browserHandler.openBrowser(authorizeRequestURI.toString()); String code = codeFuture.get(this.browserAuthorizationTimeoutSeconds, TimeUnit.SECONDS); return new AuthorizationCode(code); + } catch (TimeoutException e) { + throw new SFException( + e, + ErrorCode.OAUTH_AUTHORIZATION_CODE_FLOW_ERROR, + "Authorization request timed out. Snowflake driver did not receive authorization code back to the redirect URI. Verify your security integration and driver configuration."); } catch (Exception e) { - if (e instanceof TimeoutException) { - throw new SFException( - e, - ErrorCode.OAUTH_AUTHORIZATION_CODE_FLOW_ERROR, - "Authorization request timed out. Snowflake driver did not receive authorization code back to the redirect URI. Verify your security integration and driver configuration."); - } throw new SFException(e, ErrorCode.OAUTH_AUTHORIZATION_CODE_FLOW_ERROR, e.getMessage()); } finally { logger.debug("Stopping OAuth redirect URI server @ {}", httpServer.getAddress()); diff --git a/src/main/java/net/snowflake/client/core/auth/oauth/OAuthClientCredentialsAccessTokenProvider.java b/src/main/java/net/snowflake/client/core/auth/oauth/OAuthClientCredentialsAccessTokenProvider.java index 948f7852c..31a837e8c 100644 --- a/src/main/java/net/snowflake/client/core/auth/oauth/OAuthClientCredentialsAccessTokenProvider.java +++ b/src/main/java/net/snowflake/client/core/auth/oauth/OAuthClientCredentialsAccessTokenProvider.java @@ -1,3 +1,7 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + package net.snowflake.client.core.auth.oauth; import com.fasterxml.jackson.databind.ObjectMapper; diff --git a/src/main/java/net/snowflake/client/core/auth/oauth/TokenResponseDTO.java b/src/main/java/net/snowflake/client/core/auth/oauth/TokenResponseDTO.java index a44c21465..430c642a7 100644 --- a/src/main/java/net/snowflake/client/core/auth/oauth/TokenResponseDTO.java +++ b/src/main/java/net/snowflake/client/core/auth/oauth/TokenResponseDTO.java @@ -1,10 +1,12 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + package net.snowflake.client.core.auth.oauth; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; -import net.snowflake.client.core.SnowflakeJdbcInternalApi; -@SnowflakeJdbcInternalApi class TokenResponseDTO { private final String accessToken; diff --git a/src/test/java/net/snowflake/client/core/auth/oauth/AccessTokenProviderFactoryTest.java b/src/test/java/net/snowflake/client/core/auth/oauth/AccessTokenProviderFactoryTest.java index 54747ee8b..d05c2ed66 100644 --- a/src/test/java/net/snowflake/client/core/auth/oauth/AccessTokenProviderFactoryTest.java +++ b/src/test/java/net/snowflake/client/core/auth/oauth/AccessTokenProviderFactoryTest.java @@ -1,3 +1,7 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + package net.snowflake.client.core.auth.oauth; import java.util.Arrays; diff --git a/src/test/java/net/snowflake/client/core/auth/oauth/OAuthUtilTest.java b/src/test/java/net/snowflake/client/core/auth/oauth/OAuthUtilTest.java index bb6baded0..096fb83b6 100644 --- a/src/test/java/net/snowflake/client/core/auth/oauth/OAuthUtilTest.java +++ b/src/test/java/net/snowflake/client/core/auth/oauth/OAuthUtilTest.java @@ -1,3 +1,7 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + package net.snowflake.client.core.auth.oauth; import java.net.URI; diff --git a/src/test/java/net/snowflake/client/jdbc/OAuthAuthorizationCodeFlowLatestIT.java b/src/test/java/net/snowflake/client/jdbc/OAuthAuthorizationCodeFlowLatestIT.java index 50c501302..7bceb10f4 100644 --- a/src/test/java/net/snowflake/client/jdbc/OAuthAuthorizationCodeFlowLatestIT.java +++ b/src/test/java/net/snowflake/client/jdbc/OAuthAuthorizationCodeFlowLatestIT.java @@ -1,3 +1,7 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + package net.snowflake.client.jdbc; import static net.snowflake.client.core.SessionUtilExternalBrowser.AuthExternalBrowserHandlers; diff --git a/src/test/java/net/snowflake/client/jdbc/OAuthClientCredentialsFlowLatestIT.java b/src/test/java/net/snowflake/client/jdbc/OAuthClientCredentialsFlowLatestIT.java index ac0db9868..6caca3694 100644 --- a/src/test/java/net/snowflake/client/jdbc/OAuthClientCredentialsFlowLatestIT.java +++ b/src/test/java/net/snowflake/client/jdbc/OAuthClientCredentialsFlowLatestIT.java @@ -1,3 +1,7 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + package net.snowflake.client.jdbc; import static net.snowflake.client.core.SessionUtilExternalBrowser.AuthExternalBrowserHandlers; From 2c5896e825c594d1c945524bd717a2fa7d0f817d Mon Sep 17 00:00:00 2001 From: Dawid Heyman Date: Tue, 10 Dec 2024 12:57:00 +0100 Subject: [PATCH 5/7] CR suggestions --- src/main/java/net/snowflake/client/core/AssertUtil.java | 1 + src/main/java/net/snowflake/client/core/HttpUtil.java | 1 - src/main/java/net/snowflake/client/core/SessionUtil.java | 1 - .../auth/oauth/OAuthClientCredentialsAccessTokenProvider.java | 4 +++- 4 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/main/java/net/snowflake/client/core/AssertUtil.java b/src/main/java/net/snowflake/client/core/AssertUtil.java index 91237cf3b..149634238 100644 --- a/src/main/java/net/snowflake/client/core/AssertUtil.java +++ b/src/main/java/net/snowflake/client/core/AssertUtil.java @@ -16,6 +16,7 @@ public class AssertUtil { * @param internalErrorMesg The error message to display if condition is false * @throws SFException Will be thrown if condition is false */ + @SnowflakeJdbcInternalApi public static void assertTrue(boolean condition, String internalErrorMesg) throws SFException { if (!condition) { throw new SFException(ErrorCode.INTERNAL_ERROR, internalErrorMesg); diff --git a/src/main/java/net/snowflake/client/core/HttpUtil.java b/src/main/java/net/snowflake/client/core/HttpUtil.java index 7f2507566..23b83df09 100644 --- a/src/main/java/net/snowflake/client/core/HttpUtil.java +++ b/src/main/java/net/snowflake/client/core/HttpUtil.java @@ -842,7 +842,6 @@ private static String executeRequestInternal( SnowflakeUtil.logResponseDetails(response, logger); if (response != null) { - EntityUtils.consume(response.getEntity()); } diff --git a/src/main/java/net/snowflake/client/core/SessionUtil.java b/src/main/java/net/snowflake/client/core/SessionUtil.java index ef7682ac9..dafbdf7fd 100644 --- a/src/main/java/net/snowflake/client/core/SessionUtil.java +++ b/src/main/java/net/snowflake/client/core/SessionUtil.java @@ -289,7 +289,6 @@ static SFLoginOutput openSession( String oauthAccessToken = accessTokenProvider.getAccessToken(loginInput); loginInput.setAuthenticator(AuthenticatorType.OAUTH.name()); loginInput.setToken(oauthAccessToken); - loginInput.setUserName("0oalpyiuy8rmozhjZ5d7"); } final AuthenticatorType authenticator = getAuthenticator(loginInput); diff --git a/src/main/java/net/snowflake/client/core/auth/oauth/OAuthClientCredentialsAccessTokenProvider.java b/src/main/java/net/snowflake/client/core/auth/oauth/OAuthClientCredentialsAccessTokenProvider.java index 31a837e8c..0a84893ac 100644 --- a/src/main/java/net/snowflake/client/core/auth/oauth/OAuthClientCredentialsAccessTokenProvider.java +++ b/src/main/java/net/snowflake/client/core/auth/oauth/OAuthClientCredentialsAccessTokenProvider.java @@ -48,7 +48,9 @@ private TokenResponseDTO requestForAccessToken(SFLoginInput loginInput, TokenReq throws Exception { URI requestUri = tokenRequest.getEndpointURI(); logger.debug( - "Requesting OAuth access token from: {}", requestUri.getAuthority() + requestUri.getPath()); + "Requesting OAuth access token from: {}{}", + requestUri.getAuthority(), + requestUri.getPath()); String tokenResponse = HttpUtil.executeGeneralRequest( OAuthUtil.convertToBaseRequest(tokenRequest.toHTTPRequest()), From 94d2bc7e2dfcbc2c1a1e9f36b8cc37844dc1f43b Mon Sep 17 00:00:00 2001 From: Dawid Heyman Date: Tue, 10 Dec 2024 14:44:43 +0100 Subject: [PATCH 6/7] Remove comment --- src/main/java/net/snowflake/client/core/SessionUtil.java | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/main/java/net/snowflake/client/core/SessionUtil.java b/src/main/java/net/snowflake/client/core/SessionUtil.java index dafbdf7fd..8bd1d1626 100644 --- a/src/main/java/net/snowflake/client/core/SessionUtil.java +++ b/src/main/java/net/snowflake/client/core/SessionUtil.java @@ -223,12 +223,10 @@ private static AuthenticatorType getAuthenticator(SFLoginInput loginInput) { } else if (loginInput .getAuthenticator() .equalsIgnoreCase(AuthenticatorType.OAUTH_AUTHORIZATION_CODE.name())) { - // OAuth authorization code flow authentication return AuthenticatorType.OAUTH_AUTHORIZATION_CODE; } else if (loginInput .getAuthenticator() .equalsIgnoreCase(AuthenticatorType.OAUTH_CLIENT_CREDENTIALS.name())) { - // OAuth authorization code flow authentication return AuthenticatorType.OAUTH_CLIENT_CREDENTIALS; } else if (loginInput.getAuthenticator().equalsIgnoreCase(AuthenticatorType.OAUTH.name())) { // OAuth access code Authentication From 70a7475df9a477de1d84f1a141feadc761846475 Mon Sep 17 00:00:00 2001 From: Dawid Heyman Date: Tue, 10 Dec 2024 15:06:51 +0100 Subject: [PATCH 7/7] CR suggestions applied --- .../core/auth/oauth/AccessTokenProviderFactory.java | 6 ++---- .../snowflake/client/core/auth/oauth/OAuthUtil.java | 12 +++++------- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/src/main/java/net/snowflake/client/core/auth/oauth/AccessTokenProviderFactory.java b/src/main/java/net/snowflake/client/core/auth/oauth/AccessTokenProviderFactory.java index a2153e326..90ae34f70 100644 --- a/src/main/java/net/snowflake/client/core/auth/oauth/AccessTokenProviderFactory.java +++ b/src/main/java/net/snowflake/client/core/auth/oauth/AccessTokenProviderFactory.java @@ -22,9 +22,7 @@ public class AccessTokenProviderFactory { private static final SFLogger logger = SFLoggerFactory.getLogger(AccessTokenProviderFactory.class); - private static final AuthenticatorType[] ELIGIBLE_AUTH_TYPES = { - AuthenticatorType.OAUTH_AUTHORIZATION_CODE, AuthenticatorType.OAUTH_CLIENT_CREDENTIALS - }; + private static final Set ELIGIBLE_AUTH_TYPES = new HashSet<>(Arrays.asList(AuthenticatorType.OAUTH_AUTHORIZATION_CODE, AuthenticatorType.OAUTH_CLIENT_CREDENTIALS)); private final SessionUtilExternalBrowser.AuthExternalBrowserHandlers browserHandler; private final int browserAuthorizationTimeoutSeconds; @@ -56,7 +54,7 @@ public AccessTokenProvider createAccessTokenProvider( } public static Set getEligible() { - return new HashSet<>(Arrays.asList(ELIGIBLE_AUTH_TYPES)); + return ELIGIBLE_AUTH_TYPES; } public static boolean isEligible(AuthenticatorType authenticatorType) { diff --git a/src/main/java/net/snowflake/client/core/auth/oauth/OAuthUtil.java b/src/main/java/net/snowflake/client/core/auth/oauth/OAuthUtil.java index ddb307abd..4e214b654 100644 --- a/src/main/java/net/snowflake/client/core/auth/oauth/OAuthUtil.java +++ b/src/main/java/net/snowflake/client/core/auth/oauth/OAuthUtil.java @@ -5,39 +5,37 @@ import java.net.URI; import java.nio.charset.StandardCharsets; import net.snowflake.client.core.SFOauthLoginInput; -import net.snowflake.client.core.SnowflakeJdbcInternalApi; import org.apache.http.client.methods.HttpPost; import org.apache.http.client.methods.HttpRequestBase; import org.apache.http.entity.StringEntity; -@SnowflakeJdbcInternalApi -public class OAuthUtil { +class OAuthUtil { private static final String SNOWFLAKE_AUTHORIZE_ENDPOINT = "/oauth/authorize"; private static final String SNOWFLAKE_TOKEN_REQUEST_ENDPOINT = "/oauth/token-request"; private static final String DEFAULT_SESSION_ROLE_SCOPE_PREFIX = "session:role:"; - public static HttpRequestBase convertToBaseRequest(HTTPRequest request) { + static HttpRequestBase convertToBaseRequest(HTTPRequest request) { HttpPost baseRequest = new HttpPost(request.getURI()); baseRequest.setEntity(new StringEntity(request.getBody(), StandardCharsets.UTF_8)); request.getHeaderMap().forEach((key, values) -> baseRequest.addHeader(key, values.get(0))); return baseRequest; } - public static URI getAuthorizationUrl(SFOauthLoginInput oauthLoginInput, String serverUrl) { + static URI getAuthorizationUrl(SFOauthLoginInput oauthLoginInput, String serverUrl) { return !StringUtils.isNullOrEmpty(oauthLoginInput.getExternalAuthorizationUrl()) ? URI.create(oauthLoginInput.getExternalAuthorizationUrl()) : URI.create(serverUrl + SNOWFLAKE_AUTHORIZE_ENDPOINT); } - public static URI getTokenRequestUrl(SFOauthLoginInput oauthLoginInput, String serverUrl) { + static URI getTokenRequestUrl(SFOauthLoginInput oauthLoginInput, String serverUrl) { return !StringUtils.isNullOrEmpty(oauthLoginInput.getExternalTokenRequestUrl()) ? URI.create(oauthLoginInput.getExternalTokenRequestUrl()) : URI.create(serverUrl + SNOWFLAKE_TOKEN_REQUEST_ENDPOINT); } - public static String getScope(SFOauthLoginInput oauthLoginInput, String role) { + static String getScope(SFOauthLoginInput oauthLoginInput, String role) { return (!StringUtils.isNullOrEmpty(oauthLoginInput.getScope())) ? oauthLoginInput.getScope() : DEFAULT_SESSION_ROLE_SCOPE_PREFIX + role;