From a3dad01fc053c5377b6cfb8d51b1f33ab241df1f Mon Sep 17 00:00:00 2001 From: Dawid Heyman Date: Mon, 2 Dec 2024 16:16:51 +0100 Subject: [PATCH] Implement full flow --- .../snowflake/client/core/SFLoginInput.java | 32 +++++ .../net/snowflake/client/core/SFSession.java | 6 + .../client/core/SFSessionProperty.java | 3 + .../snowflake/client/core/SessionUtil.java | 7 +- .../core/SessionUtilExternalBrowser.java | 2 +- ...horizationCodeFlowAccessTokenProvider.java | 132 ++++++++++++------ .../core/auth/oauth/TokenResponseDTO.java | 69 +++++++++ .../snowflake/client/AbstractDriverIT.java | 3 + 8 files changed, 206 insertions(+), 48 deletions(-) create mode 100644 src/main/java/net/snowflake/client/core/auth/oauth/TokenResponseDTO.java diff --git a/src/main/java/net/snowflake/client/core/SFLoginInput.java b/src/main/java/net/snowflake/client/core/SFLoginInput.java index a6c6afd6d..48292c566 100644 --- a/src/main/java/net/snowflake/client/core/SFLoginInput.java +++ b/src/main/java/net/snowflake/client/core/SFLoginInput.java @@ -54,6 +54,11 @@ public class SFLoginInput { private boolean enableClientStoreTemporaryCredential; private boolean enableClientRequestMfaToken; + //OAuth + private int redirectUriPort = -1; + private String clientId; + private String clientSecret; + private Duration browserResponseTimeout; // Additional headers to add for Snowsight. @@ -417,6 +422,33 @@ SFLoginInput setDisableSamlURLCheck(boolean disableSamlURLCheck) { return this; } + public int getRedirectUriPort() { + return redirectUriPort; + } + + public SFLoginInput setRedirectUriPort(int redirectUriPort) { + this.redirectUriPort = redirectUriPort; + return this; + } + + public String getClientId() { + return clientId; + } + + public SFLoginInput setClientId(String clientId) { + this.clientId = clientId; + return this; + } + + public String getClientSecret() { + return clientSecret; + } + + public SFLoginInput setClientSecret(String clientSecret) { + this.clientSecret = clientSecret; + return this; + } + Map getAdditionalHttpHeadersForSnowsight() { return additionalHttpHeadersForSnowsight; } diff --git a/src/main/java/net/snowflake/client/core/SFSession.java b/src/main/java/net/snowflake/client/core/SFSession.java index 59aac9d5b..26aee2872 100644 --- a/src/main/java/net/snowflake/client/core/SFSession.java +++ b/src/main/java/net/snowflake/client/core/SFSession.java @@ -671,6 +671,8 @@ public synchronized void open() throws SFException, SnowflakeSQLException { .setSessionParameters(sessionParametersMap) .setPrivateKey((PrivateKey) connectionPropertiesMap.get(SFSessionProperty.PRIVATE_KEY)) .setPrivateKeyFile((String) connectionPropertiesMap.get(SFSessionProperty.PRIVATE_KEY_FILE)) + .setClientId((String) connectionPropertiesMap.get(SFSessionProperty.CLIENT_ID)) + .setClientSecret((String) connectionPropertiesMap.get(SFSessionProperty.CLIENT_SECRET)) .setPrivateKeyBase64( (String) connectionPropertiesMap.get(SFSessionProperty.PRIVATE_KEY_BASE64)) .setPrivateKeyPwd( @@ -696,6 +698,10 @@ public synchronized void open() throws SFException, SnowflakeSQLException { .setEnableClientRequestMfaToken(enableClientRequestMfaToken) .setBrowserResponseTimeout(browserResponseTimeout); + if (connectionPropertiesMap.containsKey(SFSessionProperty.OAUTH_REDIRECT_URI_PORT)) { + loginInput.setRedirectUriPort((Integer) connectionPropertiesMap.get(SFSessionProperty.OAUTH_REDIRECT_URI_PORT)); + } + logger.info( "Connecting to {} Snowflake domain", loginInput.getHostFromServerUrl().toLowerCase().endsWith(".cn") ? "CHINA" : "GLOBAL"); diff --git a/src/main/java/net/snowflake/client/core/SFSessionProperty.java b/src/main/java/net/snowflake/client/core/SFSessionProperty.java index 97c0adbc2..db9c386e7 100644 --- a/src/main/java/net/snowflake/client/core/SFSessionProperty.java +++ b/src/main/java/net/snowflake/client/core/SFSessionProperty.java @@ -29,6 +29,9 @@ public enum SFSessionProperty { AUTHENTICATOR("authenticator", false, String.class), OKTA_USERNAME("oktausername", false, String.class), PRIVATE_KEY("privateKey", false, PrivateKey.class), + OAUTH_REDIRECT_URI_PORT("oauthRedirectUriPort", false, Integer.class), + CLIENT_ID("clientID", false, String.class), + CLIENT_SECRET("clientSecret", false, String.class), WAREHOUSE("warehouse", false, String.class), LOGIN_TIMEOUT("loginTimeout", false, Integer.class), NETWORK_TIMEOUT("networkTimeout", false, Integer.class), diff --git a/src/main/java/net/snowflake/client/core/SessionUtil.java b/src/main/java/net/snowflake/client/core/SessionUtil.java index 17971ad37..bdd0a4daa 100644 --- a/src/main/java/net/snowflake/client/core/SessionUtil.java +++ b/src/main/java/net/snowflake/client/core/SessionUtil.java @@ -219,8 +219,11 @@ private static AuthenticatorType getAuthenticator(SFLoginInput loginInput) { .equalsIgnoreCase(AuthenticatorType.EXTERNALBROWSER.name())) { // SAML 2.0 compliant service/application return AuthenticatorType.EXTERNALBROWSER; + } else if (loginInput.getAuthenticator().equalsIgnoreCase(AuthenticatorType.OAUTH_AUTHORIZATION_CODE_FLOW.name())) { + // OAuth authorization code flow authentication + return AuthenticatorType.OAUTH_AUTHORIZATION_CODE_FLOW; } else if (loginInput.getAuthenticator().equalsIgnoreCase(AuthenticatorType.OAUTH.name())) { - // OAuth Authentication + // OAuth access code Authentication return AuthenticatorType.OAUTH; } else if (loginInput .getAuthenticator() @@ -268,6 +271,8 @@ static SFLoginOutput openSession( loginInput.getLoginTimeout() >= 0, "negative login timeout for opening session"); if (getAuthenticator(loginInput).equals(AuthenticatorType.OAUTH_AUTHORIZATION_CODE_FLOW)) { + AssertUtil.assertTrue(loginInput.getClientId() != null, "passing clientId is required for OAUTH_AUTHORIZATION_CODE_FLOW authentication"); + AssertUtil.assertTrue(loginInput.getClientSecret() != null, "passing clientSecret is required for OAUTH_AUTHORIZATION_CODE_FLOW authentication"); OauthAccessTokenProvider accessTokenProvider = new AuthorizationCodeFlowAccessTokenProvider(); String oauthAccessToken = accessTokenProvider.getAccessToken(loginInput); loginInput.setAuthenticator(AuthenticatorType.OAUTH.name()); diff --git a/src/main/java/net/snowflake/client/core/SessionUtilExternalBrowser.java b/src/main/java/net/snowflake/client/core/SessionUtilExternalBrowser.java index 0f83a9642..d15d77799 100644 --- a/src/main/java/net/snowflake/client/core/SessionUtilExternalBrowser.java +++ b/src/main/java/net/snowflake/client/core/SessionUtilExternalBrowser.java @@ -61,7 +61,7 @@ public interface AuthExternalBrowserHandlers { void output(String msg); } - static class DefaultAuthExternalBrowserHandlers implements AuthExternalBrowserHandlers { + public static class DefaultAuthExternalBrowserHandlers implements AuthExternalBrowserHandlers { @Override public HttpPost build(URI uri) { return new HttpPost(uri); diff --git a/src/main/java/net/snowflake/client/core/auth/oauth/AuthorizationCodeFlowAccessTokenProvider.java b/src/main/java/net/snowflake/client/core/auth/oauth/AuthorizationCodeFlowAccessTokenProvider.java index a507d505f..c87db49a2 100644 --- a/src/main/java/net/snowflake/client/core/auth/oauth/AuthorizationCodeFlowAccessTokenProvider.java +++ b/src/main/java/net/snowflake/client/core/auth/oauth/AuthorizationCodeFlowAccessTokenProvider.java @@ -1,97 +1,123 @@ package net.snowflake.client.core.auth.oauth; -import com.nimbusds.oauth2.sdk.AccessTokenResponse; +import com.amazonaws.util.StringUtils; +import com.fasterxml.jackson.databind.ObjectMapper; import com.nimbusds.oauth2.sdk.AuthorizationCode; import com.nimbusds.oauth2.sdk.AuthorizationCodeGrant; import com.nimbusds.oauth2.sdk.AuthorizationGrant; import com.nimbusds.oauth2.sdk.AuthorizationRequest; import com.nimbusds.oauth2.sdk.ResponseType; import com.nimbusds.oauth2.sdk.Scope; -import com.nimbusds.oauth2.sdk.TokenErrorResponse; import com.nimbusds.oauth2.sdk.TokenRequest; -import com.nimbusds.oauth2.sdk.TokenResponse; import com.nimbusds.oauth2.sdk.auth.ClientAuthentication; import com.nimbusds.oauth2.sdk.auth.ClientSecretBasic; import com.nimbusds.oauth2.sdk.auth.Secret; +import com.nimbusds.oauth2.sdk.http.HTTPRequest; import com.nimbusds.oauth2.sdk.id.ClientID; import com.nimbusds.oauth2.sdk.id.State; -import com.nimbusds.oauth2.sdk.token.AccessToken; import com.sun.net.httpserver.HttpServer; import net.snowflake.client.core.HttpUtil; import net.snowflake.client.core.SFException; import net.snowflake.client.core.SFLoginInput; import net.snowflake.client.core.SnowflakeJdbcInternalApi; -import net.snowflake.client.jdbc.ErrorCode; -import org.apache.http.client.methods.HttpGet; +import net.snowflake.client.log.SFLogger; +import net.snowflake.client.log.SFLoggerFactory; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.client.methods.HttpRequestBase; +import org.apache.http.entity.StringEntity; import java.io.IOException; import java.net.InetSocketAddress; import java.net.URI; import java.net.URISyntaxException; +import java.nio.charset.StandardCharsets; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +import static net.snowflake.client.core.SessionUtilExternalBrowser.DefaultAuthExternalBrowserHandlers; @SnowflakeJdbcInternalApi public class AuthorizationCodeFlowAccessTokenProvider implements OauthAccessTokenProvider { + private static final SFLogger logger = SFLoggerFactory.getLogger(AuthorizationCodeFlowAccessTokenProvider.class); + + private static final String AUTHORIZE_ENDPOINT = "/oauth/authorize"; + private static final String TOKEN_REQUEST_ENDPOINT = "/oauth/token-request"; + private static final String REDIRECT_URI_HOST = "localhost"; - private static final int REDIRECT_URI_PORT = 8001; - private static final String REDIRECT_URI_PATH = "/oauth-redirect"; + private static final int DEFAULT_REDIRECT_URI_PORT = 8001; + private static final String REDIRECT_URI_ENDPOINT = "/snowflake/oauth-redirect"; + public static final String SESSION_ROLE_SCOPE = "session:role"; + + public static int AUTHORIZE_REDIRECT_TIMEOUT_MINUTES = 2; + + private final DefaultAuthExternalBrowserHandlers browserUtil = new DefaultAuthExternalBrowserHandlers(); + private final ObjectMapper objectMapper = new ObjectMapper(); @Override public String getAccessToken(SFLoginInput loginInput) throws SFException { AuthorizationCode authorizationCode = requestAuthorizationCode(loginInput); - AccessToken accessToken = exchangeAuthorizationCodeForAccessToken(loginInput, authorizationCode); - return accessToken.getValue(); + return exchangeAuthorizationCodeForAccessToken(loginInput, authorizationCode); } private AuthorizationCode requestAuthorizationCode(SFLoginInput loginInput) throws SFException { try { AuthorizationRequest request = buildAuthorizationRequest(loginInput); - URI requestURI = request.toURI(); - HttpUtil.executeGeneralRequest(new HttpGet(requestURI), - loginInput.getLoginTimeout(), - loginInput.getAuthTimeout(), - loginInput.getSocketTimeoutInMillis(), - 0, - loginInput.getHttpClientSettingsKey()); - String code = getAuthorizationCodeFromRedirectURI().join(); + URI authorizeRequestURI = request.toURI(); + CompletableFuture codeFuture = setupRedirectURIServerForAuthorizationCode(loginInput.getRedirectUriPort()); + letUserAuthorizeViaBrowser(authorizeRequestURI); + String code = codeFuture.get(AUTHORIZE_REDIRECT_TIMEOUT_MINUTES, TimeUnit.MINUTES); return new AuthorizationCode(code); } catch (Exception e) { - throw new SFException(e, ErrorCode.INTERNAL_ERROR); + if (e instanceof TimeoutException) { + logger.error("Authorization request timed out. Did not receive authorization code back to the redirect URI"); + } + throw new RuntimeException(e.getMessage(), e); } } - private static AccessToken exchangeAuthorizationCodeForAccessToken(SFLoginInput loginInput, AuthorizationCode authorizationCode) throws SFException { + private String exchangeAuthorizationCodeForAccessToken(SFLoginInput loginInput, AuthorizationCode authorizationCode) throws SFException { try { TokenRequest request = buildTokenRequest(loginInput, authorizationCode); - TokenResponse response = TokenResponse.parse(request.toHTTPRequest().send()); - if (!response.indicatesSuccess()) { - TokenErrorResponse errorResponse = response.toErrorResponse(); - errorResponse.getErrorObject(); - } - AccessTokenResponse successResponse = response.toSuccessResponse(); - return successResponse.getTokens().getAccessToken(); + String tokenResponse = HttpUtil.executeGeneralRequest( + convertTokenRequest(request.toHTTPRequest()), + loginInput.getLoginTimeout(), + loginInput.getAuthTimeout(), + loginInput.getSocketTimeoutInMillis(), + 0, + loginInput.getHttpClientSettingsKey()); + TokenResponseDTO tokenResponseDTO = objectMapper.readValue(tokenResponse, TokenResponseDTO.class); + return tokenResponseDTO.getAccessToken(); } catch (Exception e) { - throw new SFException(e, ErrorCode.INTERNAL_ERROR); + throw new RuntimeException(e); } } - private static CompletableFuture getAuthorizationCodeFromRedirectURI() throws IOException { + private void letUserAuthorizeViaBrowser(URI authorizeRequestURI) throws SFException { + browserUtil.openBrowser(authorizeRequestURI.toString()); + } + + private static CompletableFuture setupRedirectURIServerForAuthorizationCode(int redirectUriPort) throws IOException { CompletableFuture accessTokenFuture = new CompletableFuture<>(); - HttpServer httpServer = HttpServer.create(new InetSocketAddress(REDIRECT_URI_HOST, REDIRECT_URI_PORT), 0); - httpServer.createContext(REDIRECT_URI_PATH, exchange -> { - String authorizationCode = exchange.getRequestURI().getQuery(); - accessTokenFuture.complete(authorizationCode); - httpServer.stop(0); + int redirectPort = (redirectUriPort != -1) ? redirectUriPort : DEFAULT_REDIRECT_URI_PORT; + HttpServer httpServer = HttpServer.create(new InetSocketAddress(REDIRECT_URI_HOST, redirectPort), 0); + httpServer.createContext(REDIRECT_URI_ENDPOINT, exchange -> { + String authorizationCode = extractAuthorizationCodeFromQueryParameters(exchange.getRequestURI().getQuery()); + if (!StringUtils.isNullOrEmpty(authorizationCode)) { + accessTokenFuture.complete(authorizationCode); + httpServer.stop(0); + } }); + httpServer.start(); return accessTokenFuture; } private static AuthorizationRequest buildAuthorizationRequest(SFLoginInput loginInput) throws URISyntaxException { - URI authorizeEndpoint = new URI(String.format("%s/oauth/authorize", loginInput.getServerUrl())); - ClientID clientID = new ClientID("123"); - Scope scope = new Scope(String.format("session:role:%s", loginInput.getRole())); - URI callback = buildRedirectURI(); + URI authorizeEndpoint = new URI(loginInput.getServerUrl() + AUTHORIZE_ENDPOINT); + ClientID clientID = new ClientID(loginInput.getClientId()); + Scope scope = new Scope(String.format("%s:%s", SESSION_ROLE_SCOPE, loginInput.getRole())); + URI callback = buildRedirectURI(loginInput.getRedirectUriPort()); State state = new State(256); return new AuthorizationRequest.Builder( new ResponseType(ResponseType.Value.CODE), clientID) @@ -102,16 +128,30 @@ private static AuthorizationRequest buildAuthorizationRequest(SFLoginInput login .build(); } - private static URI buildRedirectURI() throws URISyntaxException { - return new URI(String.format("https://%s:%s%s", REDIRECT_URI_HOST, REDIRECT_URI_PORT, REDIRECT_URI_PATH)); - } - private static TokenRequest buildTokenRequest(SFLoginInput loginInput, AuthorizationCode authorizationCode) throws URISyntaxException { - URI callback = buildRedirectURI(); + URI callback = buildRedirectURI(loginInput.getRedirectUriPort()); AuthorizationGrant codeGrant = new AuthorizationCodeGrant(authorizationCode, callback); - ClientAuthentication clientAuthentication = new ClientSecretBasic(new ClientID("123"), new Secret("123")); - URI tokenEndpoint = new URI(String.format("%s/oauth/token-request", loginInput.getServerUrl())); - Scope scope = new Scope("session:role", loginInput.getRole()); + ClientAuthentication clientAuthentication = new ClientSecretBasic(new ClientID(loginInput.getClientId()), new Secret(loginInput.getClientSecret())); + URI tokenEndpoint = new URI(String.format(loginInput.getServerUrl() + TOKEN_REQUEST_ENDPOINT)); + Scope scope = new Scope(SESSION_ROLE_SCOPE, loginInput.getRole()); return new TokenRequest(tokenEndpoint, clientAuthentication, codeGrant, scope); } + + private static URI buildRedirectURI(int redirectUriPort) throws URISyntaxException { + redirectUriPort = (redirectUriPort != -1) ? redirectUriPort : DEFAULT_REDIRECT_URI_PORT; + return new URI(String.format("http://%s:%s%s", REDIRECT_URI_HOST, redirectUriPort, REDIRECT_URI_ENDPOINT)); + } + + private static String extractAuthorizationCodeFromQueryParameters(String queryParameters) { + String prefix = "code="; + String codeSuffix = queryParameters.substring(queryParameters.indexOf(prefix) + prefix.length()); + return codeSuffix.substring(0, codeSuffix.indexOf("&")); + } + + private static HttpRequestBase convertTokenRequest(HTTPRequest nimbusRequest) { + HttpPost request = new HttpPost(nimbusRequest.getURI()); + request.setEntity(new StringEntity(nimbusRequest.getBody(), StandardCharsets.UTF_8)); + nimbusRequest.getHeaderMap().forEach((key, values) -> request.addHeader(key, values.get(0))); + return request; + } } diff --git a/src/main/java/net/snowflake/client/core/auth/oauth/TokenResponseDTO.java b/src/main/java/net/snowflake/client/core/auth/oauth/TokenResponseDTO.java new file mode 100644 index 000000000..6d00bd84c --- /dev/null +++ b/src/main/java/net/snowflake/client/core/auth/oauth/TokenResponseDTO.java @@ -0,0 +1,69 @@ +package net.snowflake.client.core.auth.oauth; + + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; + +class TokenResponseDTO { + + private final String accessToken; + private final String refreshToken; + private final String tokenType; + private final String scope; + private final String username; + private final boolean idpInitiated; + private final long expiresIn; + private final long refreshTokenExpiresIn; + + @JsonCreator(mode = JsonCreator.Mode.PROPERTIES) + public TokenResponseDTO(@JsonProperty("access_token") String accessToken, + @JsonProperty("refresh_token") String refreshToken, + @JsonProperty("token_type") String tokenType, + @JsonProperty("scope") String scope, + @JsonProperty("username") String username, + @JsonProperty("idp_initiated") boolean idpInitiated, + @JsonProperty("expires_in") long expiresIn, + @JsonProperty("refresh_token_expires_in") long refreshTokenExpiresIn) { + this.accessToken = accessToken; + this.tokenType = tokenType; + this.refreshToken = refreshToken; + this.scope = scope; + this.username = username; + this.idpInitiated = idpInitiated; + this.expiresIn = expiresIn; + this.refreshTokenExpiresIn = refreshTokenExpiresIn; + } + + public String getAccessToken() { + return accessToken; + } + + public String getTokenType() { + return tokenType; + } + + public String getRefreshToken() { + return refreshToken; + } + + public String getScope() { + return scope; + } + + public long getExpiresIn() { + return expiresIn; + } + + public String getUsername() { + return username; + } + + public long getRefreshTokenExpiresIn() { + return refreshTokenExpiresIn; + } + + public boolean isIdpInitiated() { + return idpInitiated; + } +} diff --git a/src/test/java/net/snowflake/client/AbstractDriverIT.java b/src/test/java/net/snowflake/client/AbstractDriverIT.java index 3104ce7e9..c370cd8ad 100644 --- a/src/test/java/net/snowflake/client/AbstractDriverIT.java +++ b/src/test/java/net/snowflake/client/AbstractDriverIT.java @@ -6,6 +6,8 @@ import static org.hamcrest.MatcherAssert.assertThat; import com.google.common.base.Strings; +import net.snowflake.client.core.auth.AuthenticatorType; + import java.net.URISyntaxException; import java.net.URL; import java.nio.file.Paths; @@ -323,6 +325,7 @@ public static Connection getConnection( properties.put("internal", Boolean.TRUE.toString()); // TODO: do we need this? properties.put("insecureMode", false); // use OCSP for all tests. + properties.put("authenticator", AuthenticatorType.OAUTH_AUTHORIZATION_CODE_FLOW.name()); if (injectSocketTimeout > 0) { properties.put("injectSocketTimeout", String.valueOf(injectSocketTimeout));