diff --git a/src/main/java/net/snowflake/client/core/SFLoginInput.java b/src/main/java/net/snowflake/client/core/SFLoginInput.java index 48292c566..0b9fe94cd 100644 --- a/src/main/java/net/snowflake/client/core/SFLoginInput.java +++ b/src/main/java/net/snowflake/client/core/SFLoginInput.java @@ -54,7 +54,7 @@ public class SFLoginInput { private boolean enableClientStoreTemporaryCredential; private boolean enableClientRequestMfaToken; - //OAuth + // OAuth private int redirectUriPort = -1; private String clientId; private String clientSecret; @@ -64,7 +64,7 @@ public class SFLoginInput { // Additional headers to add for Snowsight. Map additionalHttpHeadersForSnowsight; - SFLoginInput() {} + public SFLoginInput() {} Duration getBrowserResponseTimeout() { return browserResponseTimeout; @@ -79,7 +79,7 @@ public String getServerUrl() { return serverUrl; } - SFLoginInput setServerUrl(String serverUrl) { + public SFLoginInput setServerUrl(String serverUrl) { this.serverUrl = serverUrl; return this; } @@ -247,7 +247,7 @@ public int getSocketTimeoutInMillis() { return (int) socketTimeout.toMillis(); } - SFLoginInput setSocketTimeout(Duration socketTimeout) { + public SFLoginInput setSocketTimeout(Duration socketTimeout) { this.socketTimeout = socketTimeout; return this; } @@ -397,7 +397,7 @@ public HttpClientSettingsKey getHttpClientSettingsKey() { return httpClientKey; } - SFLoginInput setHttpClientSettingsKey(HttpClientSettingsKey key) { + public SFLoginInput setHttpClientSettingsKey(HttpClientSettingsKey key) { this.httpClientKey = key; return this; } diff --git a/src/main/java/net/snowflake/client/core/SFSession.java b/src/main/java/net/snowflake/client/core/SFSession.java index 26aee2872..a9b969772 100644 --- a/src/main/java/net/snowflake/client/core/SFSession.java +++ b/src/main/java/net/snowflake/client/core/SFSession.java @@ -699,7 +699,8 @@ public synchronized void open() throws SFException, SnowflakeSQLException { .setBrowserResponseTimeout(browserResponseTimeout); if (connectionPropertiesMap.containsKey(SFSessionProperty.OAUTH_REDIRECT_URI_PORT)) { - loginInput.setRedirectUriPort((Integer) connectionPropertiesMap.get(SFSessionProperty.OAUTH_REDIRECT_URI_PORT)); + loginInput.setRedirectUriPort( + (Integer) connectionPropertiesMap.get(SFSessionProperty.OAUTH_REDIRECT_URI_PORT)); } logger.info( diff --git a/src/main/java/net/snowflake/client/core/SessionUtil.java b/src/main/java/net/snowflake/client/core/SessionUtil.java index bdd0a4daa..3a52b02e4 100644 --- a/src/main/java/net/snowflake/client/core/SessionUtil.java +++ b/src/main/java/net/snowflake/client/core/SessionUtil.java @@ -219,7 +219,9 @@ private static AuthenticatorType getAuthenticator(SFLoginInput loginInput) { .equalsIgnoreCase(AuthenticatorType.EXTERNALBROWSER.name())) { // SAML 2.0 compliant service/application return AuthenticatorType.EXTERNALBROWSER; - } else if (loginInput.getAuthenticator().equalsIgnoreCase(AuthenticatorType.OAUTH_AUTHORIZATION_CODE_FLOW.name())) { + } else if (loginInput + .getAuthenticator() + .equalsIgnoreCase(AuthenticatorType.OAUTH_AUTHORIZATION_CODE_FLOW.name())) { // OAuth authorization code flow authentication return AuthenticatorType.OAUTH_AUTHORIZATION_CODE_FLOW; } else if (loginInput.getAuthenticator().equalsIgnoreCase(AuthenticatorType.OAUTH.name())) { @@ -271,9 +273,15 @@ static SFLoginOutput openSession( loginInput.getLoginTimeout() >= 0, "negative login timeout for opening session"); if (getAuthenticator(loginInput).equals(AuthenticatorType.OAUTH_AUTHORIZATION_CODE_FLOW)) { - AssertUtil.assertTrue(loginInput.getClientId() != null, "passing clientId is required for OAUTH_AUTHORIZATION_CODE_FLOW authentication"); - AssertUtil.assertTrue(loginInput.getClientSecret() != null, "passing clientSecret is required for OAUTH_AUTHORIZATION_CODE_FLOW authentication"); - OauthAccessTokenProvider accessTokenProvider = new AuthorizationCodeFlowAccessTokenProvider(); + AssertUtil.assertTrue( + loginInput.getClientId() != null, + "passing clientId is required for OAUTH_AUTHORIZATION_CODE_FLOW authentication"); + AssertUtil.assertTrue( + loginInput.getClientSecret() != null, + "passing clientSecret is required for OAUTH_AUTHORIZATION_CODE_FLOW authentication"); + OauthAccessTokenProvider accessTokenProvider = + new AuthorizationCodeFlowAccessTokenProvider( + new SessionUtilExternalBrowser.DefaultAuthExternalBrowserHandlers()); String oauthAccessToken = accessTokenProvider.getAccessToken(loginInput); loginInput.setAuthenticator(AuthenticatorType.OAUTH.name()); loginInput.setToken(oauthAccessToken); diff --git a/src/main/java/net/snowflake/client/core/auth/oauth/AuthorizationCodeFlowAccessTokenProvider.java b/src/main/java/net/snowflake/client/core/auth/oauth/AuthorizationCodeFlowAccessTokenProvider.java index c87db49a2..32e27e844 100644 --- a/src/main/java/net/snowflake/client/core/auth/oauth/AuthorizationCodeFlowAccessTokenProvider.java +++ b/src/main/java/net/snowflake/client/core/auth/oauth/AuthorizationCodeFlowAccessTokenProvider.java @@ -1,5 +1,7 @@ package net.snowflake.client.core.auth.oauth; +import static net.snowflake.client.core.SessionUtilExternalBrowser.*; + import com.amazonaws.util.StringUtils; import com.fasterxml.jackson.databind.ObjectMapper; import com.nimbusds.oauth2.sdk.AuthorizationCode; @@ -15,7 +17,17 @@ import com.nimbusds.oauth2.sdk.http.HTTPRequest; import com.nimbusds.oauth2.sdk.id.ClientID; import com.nimbusds.oauth2.sdk.id.State; +import com.nimbusds.oauth2.sdk.pkce.CodeChallengeMethod; +import com.nimbusds.oauth2.sdk.pkce.CodeVerifier; import com.sun.net.httpserver.HttpServer; +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.charset.StandardCharsets; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import net.snowflake.client.core.HttpUtil; import net.snowflake.client.core.SFException; import net.snowflake.client.core.SFLoginInput; @@ -26,132 +38,154 @@ import org.apache.http.client.methods.HttpRequestBase; import org.apache.http.entity.StringEntity; -import java.io.IOException; -import java.net.InetSocketAddress; -import java.net.URI; -import java.net.URISyntaxException; -import java.nio.charset.StandardCharsets; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; - -import static net.snowflake.client.core.SessionUtilExternalBrowser.DefaultAuthExternalBrowserHandlers; - @SnowflakeJdbcInternalApi public class AuthorizationCodeFlowAccessTokenProvider implements OauthAccessTokenProvider { - private static final SFLogger logger = SFLoggerFactory.getLogger(AuthorizationCodeFlowAccessTokenProvider.class); - - private static final String AUTHORIZE_ENDPOINT = "/oauth/authorize"; - private static final String TOKEN_REQUEST_ENDPOINT = "/oauth/token-request"; - - private static final String REDIRECT_URI_HOST = "localhost"; - private static final int DEFAULT_REDIRECT_URI_PORT = 8001; - private static final String REDIRECT_URI_ENDPOINT = "/snowflake/oauth-redirect"; - public static final String SESSION_ROLE_SCOPE = "session:role"; - - public static int AUTHORIZE_REDIRECT_TIMEOUT_MINUTES = 2; - - private final DefaultAuthExternalBrowserHandlers browserUtil = new DefaultAuthExternalBrowserHandlers(); - private final ObjectMapper objectMapper = new ObjectMapper(); - - @Override - public String getAccessToken(SFLoginInput loginInput) throws SFException { - AuthorizationCode authorizationCode = requestAuthorizationCode(loginInput); - return exchangeAuthorizationCodeForAccessToken(loginInput, authorizationCode); - } - - private AuthorizationCode requestAuthorizationCode(SFLoginInput loginInput) throws SFException { - try { - AuthorizationRequest request = buildAuthorizationRequest(loginInput); - URI authorizeRequestURI = request.toURI(); - CompletableFuture codeFuture = setupRedirectURIServerForAuthorizationCode(loginInput.getRedirectUriPort()); - letUserAuthorizeViaBrowser(authorizeRequestURI); - String code = codeFuture.get(AUTHORIZE_REDIRECT_TIMEOUT_MINUTES, TimeUnit.MINUTES); - return new AuthorizationCode(code); - } catch (Exception e) { - if (e instanceof TimeoutException) { - logger.error("Authorization request timed out. Did not receive authorization code back to the redirect URI"); - } - throw new RuntimeException(e.getMessage(), e); - } + private static final SFLogger logger = + SFLoggerFactory.getLogger(AuthorizationCodeFlowAccessTokenProvider.class); + + private static final String SNOWFLAKE_AUTHORIZE_ENDPOINT = "/oauth/authorize"; + private static final String SNOWFLAKE_TOKEN_REQUEST_ENDPOINT = "/oauth/token-request"; + + private static final String REDIRECT_URI_HOST = "localhost"; + private static final int DEFAULT_REDIRECT_URI_PORT = 8001; + private static final String REDIRECT_URI_ENDPOINT = "/snowflake/oauth-redirect"; + public static final String SESSION_ROLE_SCOPE = "session:role"; + + public static int AUTHORIZE_REDIRECT_TIMEOUT_MINUTES = 2; + + private final AuthExternalBrowserHandlers browserHandler; + private final ObjectMapper objectMapper = new ObjectMapper(); + + public AuthorizationCodeFlowAccessTokenProvider(AuthExternalBrowserHandlers browserHandler) { + this.browserHandler = browserHandler; + } + + @Override + public String getAccessToken(SFLoginInput loginInput) throws SFException { + CodeVerifier pkceVerifier = new CodeVerifier(); + AuthorizationCode authorizationCode = requestAuthorizationCode(loginInput, pkceVerifier); + return exchangeAuthorizationCodeForAccessToken(loginInput, authorizationCode, pkceVerifier); + } + + private AuthorizationCode requestAuthorizationCode( + SFLoginInput loginInput, CodeVerifier pkceVerifier) throws SFException { + try { + AuthorizationRequest request = buildAuthorizationRequest(loginInput, pkceVerifier); + URI authorizeRequestURI = request.toURI(); + CompletableFuture codeFuture = + setupRedirectURIServerForAuthorizationCode(loginInput.getRedirectUriPort()); + letUserAuthorizeViaBrowser(authorizeRequestURI); + String code = codeFuture.get(AUTHORIZE_REDIRECT_TIMEOUT_MINUTES, TimeUnit.MINUTES); + return new AuthorizationCode(code); + } catch (Exception e) { + if (e instanceof TimeoutException) { + logger.error( + "Authorization request timed out. Did not receive authorization code back to the redirect URI"); + } + throw new RuntimeException(e); } - - private String exchangeAuthorizationCodeForAccessToken(SFLoginInput loginInput, AuthorizationCode authorizationCode) throws SFException { - try { - TokenRequest request = buildTokenRequest(loginInput, authorizationCode); - String tokenResponse = HttpUtil.executeGeneralRequest( - convertTokenRequest(request.toHTTPRequest()), - loginInput.getLoginTimeout(), - loginInput.getAuthTimeout(), - loginInput.getSocketTimeoutInMillis(), - 0, - loginInput.getHttpClientSettingsKey()); - TokenResponseDTO tokenResponseDTO = objectMapper.readValue(tokenResponse, TokenResponseDTO.class); - return tokenResponseDTO.getAccessToken(); - } catch (Exception e) { - throw new RuntimeException(e); - } + } + + private String exchangeAuthorizationCodeForAccessToken( + SFLoginInput loginInput, AuthorizationCode authorizationCode, CodeVerifier pkceVerifier) + throws SFException { + try { + TokenRequest request = buildTokenRequest(loginInput, authorizationCode, pkceVerifier); + String tokenResponse = + HttpUtil.executeGeneralRequest( + convertTokenRequest(request.toHTTPRequest()), + loginInput.getLoginTimeout(), + loginInput.getAuthTimeout(), + loginInput.getSocketTimeoutInMillis(), + 0, + loginInput.getHttpClientSettingsKey()); + TokenResponseDTO tokenResponseDTO = + objectMapper.readValue(tokenResponse, TokenResponseDTO.class); + return tokenResponseDTO.getAccessToken(); + } catch (Exception e) { + throw new RuntimeException(e); } - - private void letUserAuthorizeViaBrowser(URI authorizeRequestURI) throws SFException { - browserUtil.openBrowser(authorizeRequestURI.toString()); - } - - private static CompletableFuture setupRedirectURIServerForAuthorizationCode(int redirectUriPort) throws IOException { - CompletableFuture accessTokenFuture = new CompletableFuture<>(); - int redirectPort = (redirectUriPort != -1) ? redirectUriPort : DEFAULT_REDIRECT_URI_PORT; - HttpServer httpServer = HttpServer.create(new InetSocketAddress(REDIRECT_URI_HOST, redirectPort), 0); - httpServer.createContext(REDIRECT_URI_ENDPOINT, exchange -> { - String authorizationCode = extractAuthorizationCodeFromQueryParameters(exchange.getRequestURI().getQuery()); - if (!StringUtils.isNullOrEmpty(authorizationCode)) { - accessTokenFuture.complete(authorizationCode); - httpServer.stop(0); - } + } + + private void letUserAuthorizeViaBrowser(URI authorizeRequestURI) throws SFException { + browserHandler.openBrowser(authorizeRequestURI.toString()); + } + + private static CompletableFuture setupRedirectURIServerForAuthorizationCode( + int redirectUriPort) throws IOException { + CompletableFuture accessTokenFuture = new CompletableFuture<>(); + int redirectPort = (redirectUriPort != -1) ? redirectUriPort : DEFAULT_REDIRECT_URI_PORT; + HttpServer httpServer = + HttpServer.create(new InetSocketAddress(REDIRECT_URI_HOST, redirectPort), 0); + httpServer.createContext( + REDIRECT_URI_ENDPOINT, + exchange -> { + exchange.sendResponseHeaders(200, 0); + exchange.getResponseBody().close(); + String authorizationCode = + extractAuthorizationCodeFromQueryParameters(exchange.getRequestURI().getQuery()); + if (!StringUtils.isNullOrEmpty(authorizationCode)) { + accessTokenFuture.complete(authorizationCode); + httpServer.stop(0); + } }); - httpServer.start(); - return accessTokenFuture; - } - - private static AuthorizationRequest buildAuthorizationRequest(SFLoginInput loginInput) throws URISyntaxException { - URI authorizeEndpoint = new URI(loginInput.getServerUrl() + AUTHORIZE_ENDPOINT); - ClientID clientID = new ClientID(loginInput.getClientId()); - Scope scope = new Scope(String.format("%s:%s", SESSION_ROLE_SCOPE, loginInput.getRole())); - URI callback = buildRedirectURI(loginInput.getRedirectUriPort()); - State state = new State(256); - return new AuthorizationRequest.Builder( - new ResponseType(ResponseType.Value.CODE), clientID) - .scope(scope) - .state(state) - .redirectionURI(callback) - .endpointURI(authorizeEndpoint) - .build(); - } - - private static TokenRequest buildTokenRequest(SFLoginInput loginInput, AuthorizationCode authorizationCode) throws URISyntaxException { - URI callback = buildRedirectURI(loginInput.getRedirectUriPort()); - AuthorizationGrant codeGrant = new AuthorizationCodeGrant(authorizationCode, callback); - ClientAuthentication clientAuthentication = new ClientSecretBasic(new ClientID(loginInput.getClientId()), new Secret(loginInput.getClientSecret())); - URI tokenEndpoint = new URI(String.format(loginInput.getServerUrl() + TOKEN_REQUEST_ENDPOINT)); - Scope scope = new Scope(SESSION_ROLE_SCOPE, loginInput.getRole()); - return new TokenRequest(tokenEndpoint, clientAuthentication, codeGrant, scope); - } - - private static URI buildRedirectURI(int redirectUriPort) throws URISyntaxException { - redirectUriPort = (redirectUriPort != -1) ? redirectUriPort : DEFAULT_REDIRECT_URI_PORT; - return new URI(String.format("http://%s:%s%s", REDIRECT_URI_HOST, redirectUriPort, REDIRECT_URI_ENDPOINT)); - } - - private static String extractAuthorizationCodeFromQueryParameters(String queryParameters) { - String prefix = "code="; - String codeSuffix = queryParameters.substring(queryParameters.indexOf(prefix) + prefix.length()); - return codeSuffix.substring(0, codeSuffix.indexOf("&")); - } - - private static HttpRequestBase convertTokenRequest(HTTPRequest nimbusRequest) { - HttpPost request = new HttpPost(nimbusRequest.getURI()); - request.setEntity(new StringEntity(nimbusRequest.getBody(), StandardCharsets.UTF_8)); - nimbusRequest.getHeaderMap().forEach((key, values) -> request.addHeader(key, values.get(0))); - return request; + httpServer.start(); + return accessTokenFuture; + } + + private static AuthorizationRequest buildAuthorizationRequest( + SFLoginInput loginInput, CodeVerifier pkceVerifier) throws URISyntaxException { + URI authorizeEndpoint = new URI(loginInput.getServerUrl() + SNOWFLAKE_AUTHORIZE_ENDPOINT); + ClientID clientID = new ClientID(loginInput.getClientId()); + Scope scope = new Scope(String.format("%s:%s", SESSION_ROLE_SCOPE, loginInput.getRole())); + URI callback = buildRedirectURI(loginInput.getRedirectUriPort()); + State state = new State(256); + return new AuthorizationRequest.Builder(new ResponseType(ResponseType.Value.CODE), clientID) + .scope(scope) + .state(state) + .redirectionURI(callback) + .codeChallenge(pkceVerifier, CodeChallengeMethod.S256) + .endpointURI(authorizeEndpoint) + .build(); + } + + private static TokenRequest buildTokenRequest( + SFLoginInput loginInput, AuthorizationCode authorizationCode, CodeVerifier pkceVerifier) + throws URISyntaxException { + URI redirectURI = buildRedirectURI(loginInput.getRedirectUriPort()); + AuthorizationGrant codeGrant = + new AuthorizationCodeGrant(authorizationCode, redirectURI, pkceVerifier); + ClientAuthentication clientAuthentication = + new ClientSecretBasic( + new ClientID(loginInput.getClientId()), new Secret(loginInput.getClientSecret())); + URI tokenEndpoint = + new URI(String.format(loginInput.getServerUrl() + SNOWFLAKE_TOKEN_REQUEST_ENDPOINT)); + Scope scope = new Scope(SESSION_ROLE_SCOPE, loginInput.getRole()); + return new TokenRequest(tokenEndpoint, clientAuthentication, codeGrant, scope); + } + + private static URI buildRedirectURI(int redirectUriPort) throws URISyntaxException { + redirectUriPort = (redirectUriPort != -1) ? redirectUriPort : DEFAULT_REDIRECT_URI_PORT; + return new URI( + String.format("http://%s:%s%s", REDIRECT_URI_HOST, redirectUriPort, REDIRECT_URI_ENDPOINT)); + } + + private static String extractAuthorizationCodeFromQueryParameters(String queryParameters) { + String prefix = "code="; + String codeSuffix = + queryParameters.substring(queryParameters.indexOf(prefix) + prefix.length()); + if (codeSuffix.contains("&")) { + return codeSuffix.substring(0, codeSuffix.indexOf("&")); + } else { + return codeSuffix; } + } + + private static HttpRequestBase convertTokenRequest(HTTPRequest nimbusRequest) { + HttpPost request = new HttpPost(nimbusRequest.getURI()); + request.setEntity(new StringEntity(nimbusRequest.getBody(), StandardCharsets.UTF_8)); + nimbusRequest.getHeaderMap().forEach((key, values) -> request.addHeader(key, values.get(0))); + return request; + } } diff --git a/src/main/java/net/snowflake/client/core/auth/oauth/OauthAccessTokenProvider.java b/src/main/java/net/snowflake/client/core/auth/oauth/OauthAccessTokenProvider.java index 60b311953..713e6a282 100644 --- a/src/main/java/net/snowflake/client/core/auth/oauth/OauthAccessTokenProvider.java +++ b/src/main/java/net/snowflake/client/core/auth/oauth/OauthAccessTokenProvider.java @@ -7,5 +7,5 @@ @SnowflakeJdbcInternalApi public interface OauthAccessTokenProvider { - String getAccessToken(SFLoginInput loginInput) throws SFException; + String getAccessToken(SFLoginInput loginInput) throws SFException; } diff --git a/src/main/java/net/snowflake/client/core/auth/oauth/TokenResponseDTO.java b/src/main/java/net/snowflake/client/core/auth/oauth/TokenResponseDTO.java index 6d00bd84c..3db842276 100644 --- a/src/main/java/net/snowflake/client/core/auth/oauth/TokenResponseDTO.java +++ b/src/main/java/net/snowflake/client/core/auth/oauth/TokenResponseDTO.java @@ -1,69 +1,68 @@ package net.snowflake.client.core.auth.oauth; - import com.fasterxml.jackson.annotation.JsonCreator; -import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; class TokenResponseDTO { - private final String accessToken; - private final String refreshToken; - private final String tokenType; - private final String scope; - private final String username; - private final boolean idpInitiated; - private final long expiresIn; - private final long refreshTokenExpiresIn; + private final String accessToken; + private final String refreshToken; + private final String tokenType; + private final String scope; + private final String username; + private final boolean idpInitiated; + private final long expiresIn; + private final long refreshTokenExpiresIn; - @JsonCreator(mode = JsonCreator.Mode.PROPERTIES) - public TokenResponseDTO(@JsonProperty("access_token") String accessToken, - @JsonProperty("refresh_token") String refreshToken, - @JsonProperty("token_type") String tokenType, - @JsonProperty("scope") String scope, - @JsonProperty("username") String username, - @JsonProperty("idp_initiated") boolean idpInitiated, - @JsonProperty("expires_in") long expiresIn, - @JsonProperty("refresh_token_expires_in") long refreshTokenExpiresIn) { - this.accessToken = accessToken; - this.tokenType = tokenType; - this.refreshToken = refreshToken; - this.scope = scope; - this.username = username; - this.idpInitiated = idpInitiated; - this.expiresIn = expiresIn; - this.refreshTokenExpiresIn = refreshTokenExpiresIn; - } + @JsonCreator(mode = JsonCreator.Mode.PROPERTIES) + public TokenResponseDTO( + @JsonProperty("access_token") String accessToken, + @JsonProperty("refresh_token") String refreshToken, + @JsonProperty("token_type") String tokenType, + @JsonProperty("scope") String scope, + @JsonProperty("username") String username, + @JsonProperty("idp_initiated") boolean idpInitiated, + @JsonProperty("expires_in") long expiresIn, + @JsonProperty("refresh_token_expires_in") long refreshTokenExpiresIn) { + this.accessToken = accessToken; + this.tokenType = tokenType; + this.refreshToken = refreshToken; + this.scope = scope; + this.username = username; + this.idpInitiated = idpInitiated; + this.expiresIn = expiresIn; + this.refreshTokenExpiresIn = refreshTokenExpiresIn; + } - public String getAccessToken() { - return accessToken; - } + public String getAccessToken() { + return accessToken; + } - public String getTokenType() { - return tokenType; - } + public String getTokenType() { + return tokenType; + } - public String getRefreshToken() { - return refreshToken; - } + public String getRefreshToken() { + return refreshToken; + } - public String getScope() { - return scope; - } + public String getScope() { + return scope; + } - public long getExpiresIn() { - return expiresIn; - } + public long getExpiresIn() { + return expiresIn; + } - public String getUsername() { - return username; - } + public String getUsername() { + return username; + } - public long getRefreshTokenExpiresIn() { - return refreshTokenExpiresIn; - } + public long getRefreshTokenExpiresIn() { + return refreshTokenExpiresIn; + } - public boolean isIdpInitiated() { - return idpInitiated; - } + public boolean isIdpInitiated() { + return idpInitiated; + } } diff --git a/src/test/java/net/snowflake/client/AbstractDriverIT.java b/src/test/java/net/snowflake/client/AbstractDriverIT.java index c370cd8ad..113fedeee 100644 --- a/src/test/java/net/snowflake/client/AbstractDriverIT.java +++ b/src/test/java/net/snowflake/client/AbstractDriverIT.java @@ -6,8 +6,6 @@ import static org.hamcrest.MatcherAssert.assertThat; import com.google.common.base.Strings; -import net.snowflake.client.core.auth.AuthenticatorType; - import java.net.URISyntaxException; import java.net.URL; import java.nio.file.Paths; @@ -26,6 +24,7 @@ import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; +import net.snowflake.client.core.auth.AuthenticatorType; /** Base test class with common constants, data structures and methods */ public class AbstractDriverIT { @@ -326,6 +325,8 @@ public static Connection getConnection( properties.put("internal", Boolean.TRUE.toString()); // TODO: do we need this? properties.put("insecureMode", false); // use OCSP for all tests. properties.put("authenticator", AuthenticatorType.OAUTH_AUTHORIZATION_CODE_FLOW.name()); + properties.put("clientId", "IJF3fOhk3Ap614HGkKFt9+Ow1LA="); + properties.put("clientSecret", "kX0l9bGGLnuLkByufjUeSG0OLoi2Hz/Nw/31pKXqpE4="); if (injectSocketTimeout > 0) { properties.put("injectSocketTimeout", String.valueOf(injectSocketTimeout)); diff --git a/src/test/java/net/snowflake/client/jdbc/BaseWiremockTest.java b/src/test/java/net/snowflake/client/jdbc/BaseWiremockTest.java index 08069b95c..8089387a2 100644 --- a/src/test/java/net/snowflake/client/jdbc/BaseWiremockTest.java +++ b/src/test/java/net/snowflake/client/jdbc/BaseWiremockTest.java @@ -25,6 +25,7 @@ import org.apache.http.entity.StringEntity; import org.apache.http.impl.client.CloseableHttpClient; import org.apache.http.impl.client.HttpClients; +import org.apache.http.util.EntityUtils; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Assumptions; @@ -232,7 +233,8 @@ protected void addMapping(String mapping) { HttpPost postRequest = createWiremockPostRequest(mapping, "/__admin/mappings"); try (CloseableHttpClient client = HttpClients.createDefault(); CloseableHttpResponse response = client.execute(postRequest)) { - assertEquals(201, response.getStatusLine().getStatusCode()); + String responseBody = EntityUtils.toString(response.getEntity(), "UTF-8"); + assertEquals(201, response.getStatusLine().getStatusCode(), responseBody); } catch (IOException e) { throw new RuntimeException(e); } diff --git a/src/test/java/net/snowflake/client/jdbc/OauthAuthorizationCodeFlowLatestIT.java b/src/test/java/net/snowflake/client/jdbc/OauthAuthorizationCodeFlowLatestIT.java new file mode 100644 index 000000000..cbcb8649f --- /dev/null +++ b/src/test/java/net/snowflake/client/jdbc/OauthAuthorizationCodeFlowLatestIT.java @@ -0,0 +1,136 @@ +package net.snowflake.client.jdbc; + +import static net.snowflake.client.core.SessionUtilExternalBrowser.*; + +import java.io.IOException; +import java.net.URI; +import java.time.Duration; +import net.snowflake.client.category.TestTags; +import net.snowflake.client.core.HttpClientSettingsKey; +import net.snowflake.client.core.OCSPMode; +import net.snowflake.client.core.SFException; +import net.snowflake.client.core.SFLoginInput; +import net.snowflake.client.core.auth.oauth.AuthorizationCodeFlowAccessTokenProvider; +import net.snowflake.client.core.auth.oauth.OauthAccessTokenProvider; +import org.apache.http.client.methods.HttpGet; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.impl.client.CloseableHttpClient; +import org.apache.http.impl.client.HttpClients; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; +import org.junit.platform.commons.util.StringUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +@Tag(TestTags.CORE) +public class OauthAuthorizationCodeFlowLatestIT extends BaseWiremockTest { + + public static final String SUCCESSFUL_FLOW_SCENARIO_MAPPINGS = + "{\n" + + " \"mappings\": [\n" + + " {\n" + + " \"scenarioName\": \"Successful OAuth authorization code flow\",\n" + + " \"requiredScenarioState\": \"Started\",\n" + + " \"newScenarioState\": \"Authorized\",\n" + + " \"request\": {\n" + + " \"urlPathPattern\": \"/oauth/authorize.*\",\n" + + " \"method\": \"GET\"\n" + + " },\n" + + " \"response\": {\n" + + " \"status\": 200\n" + + " },\n" + + " \"serveEventListeners\": [\n" + + " {\n" + + " \"name\": \"webhook\",\n" + + " \"parameters\": {\n" + + " \"method\": \"GET\",\n" + + " \"url\": \"http://localhost:8001/snowflake/oauth-redirect?code=123\"\n" + + " }\n" + + " }\n" + + " ]\n" + + " },\n" + + " {\n" + + " \"scenarioName\": \"Successful OAuth authorization code flow\",\n" + + " \"requiredScenarioState\": \"Authorized\",\n" + + " \"newScenarioState\": \"Acquired access token\",\n" + + " \"request\": {\n" + + " \"urlPathPattern\": \"/oauth/token-request.*\",\n" + + " \"method\": \"POST\",\n" + + " \"headers\": {\n" + + " \"Authorization\": {\n" + + " \"contains\": \"Basic\"\n" + + " },\n" + + " \"Content-Type\": {\n" + + " \"contains\": \"application/x-www-form-urlencoded; charset=UTF-8\"\n" + + " }\n" + + " },\n" + + " \"bodyPatterns\": [{\n" + + " \"contains\": \"grant_type=authorization_code&code=123&redirect_uri=http%3A%2F%2Flocalhost%3A8001%2Fsnowflake%2Foauth-redirect&code_verifier=\"\n" + + " }]\n" + + " },\n" + + " \"response\": {\n" + + " \"status\": 200,\n" + + " \"body\": \"{ \\\"access_token\\\" : \\\"access-token-123\\\", \\\"refresh_token\\\" : \\\"123\\\", \\\"token_type\\\" : \\\"Bearer\\\", \\\"username\\\" : \\\"user\\\", \\\"scope\\\" : \\\"refresh_token session:role:ANALYST\\\", \\\"expires_in\\\" : 600, \\\"refresh_token_expires_in\\\" : 86399, \\\"idpInitiated\\\" : false }\"\n" + + " }\n" + + " }\n" + + " ],\n" + + " \"importOptions\": {\n" + + " \"duplicatePolicy\": \"IGNORE\",\n" + + " \"deleteAllNotInImport\": true\n" + + " }\n" + + "}"; + + private static final Logger log = + LoggerFactory.getLogger(OauthAuthorizationCodeFlowLatestIT.class); + + AuthExternalBrowserHandlers wiremockProxyRequestBrowserHandler = + new WiremockProxyRequestBrowserHandler(); + + @Test + public void successfulFlowScenario() throws SFException { + importMapping(SUCCESSFUL_FLOW_SCENARIO_MAPPINGS); + SFLoginInput loginInput = createLoginInputStub(); + + OauthAccessTokenProvider provider = + new AuthorizationCodeFlowAccessTokenProvider(wiremockProxyRequestBrowserHandler); + String accessToken = provider.getAccessToken(loginInput); + + Assertions.assertTrue(StringUtils.isNotBlank(accessToken)); + Assertions.assertEquals("access-token-123", accessToken); + } + + private SFLoginInput createLoginInputStub() { + SFLoginInput loginInputStub = new SFLoginInput(); + loginInputStub.setServerUrl(String.format("http://%s:%d/", WIREMOCK_HOST, wiremockHttpPort)); + loginInputStub.setClientSecret("123"); + loginInputStub.setClientId("123"); + loginInputStub.setRole("ANALYST"); + loginInputStub.setSocketTimeout(Duration.ofMinutes(5)); + loginInputStub.setHttpClientSettingsKey(new HttpClientSettingsKey(OCSPMode.FAIL_OPEN)); + + return loginInputStub; + } + + static class WiremockProxyRequestBrowserHandler implements AuthExternalBrowserHandlers { + @Override + public HttpPost build(URI uri) { + // do nothing + return null; + } + + @Override + public void openBrowser(String ssoUrl) { + try (CloseableHttpClient client = HttpClients.createDefault()) { + client.execute(new HttpGet(ssoUrl)); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public void output(String msg) { + // do nothing + } + } +}