diff --git a/parent-pom.xml b/parent-pom.xml index b1742d64e..40a3656dd 100644 --- a/parent-pom.xml +++ b/parent-pom.xml @@ -70,6 +70,7 @@ 4.11.0 4.1.115.Final 9.37.3 + 11.20.1 0.31.1 1.0-alpha-9-stable-1 3.4.2 @@ -218,6 +219,11 @@ nimbus-jose-jwt ${nimbusds.version} + + com.nimbusds + oauth2-oidc-sdk + ${nimbusds.oauth2.version} + com.yammer.metrics metrics-core @@ -645,6 +651,10 @@ com.nimbusds nimbus-jose-jwt + + com.nimbusds + oauth2-oidc-sdk + com.yammer.metrics metrics-core diff --git a/src/main/java/net/snowflake/client/core/SFLoginInput.java b/src/main/java/net/snowflake/client/core/SFLoginInput.java index 5f52b64af..a6c6afd6d 100644 --- a/src/main/java/net/snowflake/client/core/SFLoginInput.java +++ b/src/main/java/net/snowflake/client/core/SFLoginInput.java @@ -160,7 +160,7 @@ public SFLoginInput setAccountName(String accountName) { return this; } - int getLoginTimeout() { + public int getLoginTimeout() { return loginTimeout; } @@ -184,7 +184,7 @@ SFLoginInput setRetryTimeout(int retryTimeout) { return this; } - int getAuthTimeout() { + public int getAuthTimeout() { return authTimeout; } @@ -238,7 +238,7 @@ SFLoginInput setConnectionTimeout(Duration connectionTimeout) { return this; } - int getSocketTimeoutInMillis() { + public int getSocketTimeoutInMillis() { return (int) socketTimeout.toMillis(); } @@ -388,7 +388,7 @@ SFLoginInput setOCSPMode(OCSPMode ocspMode) { return this; } - HttpClientSettingsKey getHttpClientSettingsKey() { + public HttpClientSettingsKey getHttpClientSettingsKey() { return httpClientKey; } diff --git a/src/main/java/net/snowflake/client/core/SessionUtil.java b/src/main/java/net/snowflake/client/core/SessionUtil.java index e13c21162..17971ad37 100644 --- a/src/main/java/net/snowflake/client/core/SessionUtil.java +++ b/src/main/java/net/snowflake/client/core/SessionUtil.java @@ -28,6 +28,8 @@ import net.snowflake.client.core.auth.AuthenticatorType; import net.snowflake.client.core.auth.ClientAuthnDTO; import net.snowflake.client.core.auth.ClientAuthnParameter; +import net.snowflake.client.core.auth.oauth.AuthorizationCodeFlowAccessTokenProvider; +import net.snowflake.client.core.auth.oauth.OauthAccessTokenProvider; import net.snowflake.client.jdbc.ErrorCode; import net.snowflake.client.jdbc.SnowflakeDriver; import net.snowflake.client.jdbc.SnowflakeReauthenticationRequest; @@ -265,6 +267,13 @@ static SFLoginOutput openSession( AssertUtil.assertTrue( loginInput.getLoginTimeout() >= 0, "negative login timeout for opening session"); + if (getAuthenticator(loginInput).equals(AuthenticatorType.OAUTH_AUTHORIZATION_CODE_FLOW)) { + OauthAccessTokenProvider accessTokenProvider = new AuthorizationCodeFlowAccessTokenProvider(); + String oauthAccessToken = accessTokenProvider.getAccessToken(loginInput); + loginInput.setAuthenticator(AuthenticatorType.OAUTH.name()); + loginInput.setToken(oauthAccessToken); + } + final AuthenticatorType authenticator = getAuthenticator(loginInput); if (!authenticator.equals(AuthenticatorType.OAUTH)) { // OAuth does not require a username diff --git a/src/main/java/net/snowflake/client/core/auth/AuthenticatorType.java b/src/main/java/net/snowflake/client/core/auth/AuthenticatorType.java index e25af718a..32d28dd83 100644 --- a/src/main/java/net/snowflake/client/core/auth/AuthenticatorType.java +++ b/src/main/java/net/snowflake/client/core/auth/AuthenticatorType.java @@ -41,5 +41,10 @@ public enum AuthenticatorType { /* * Authenticator to enable token for regular login with mfa */ - USERNAME_PASSWORD_MFA + USERNAME_PASSWORD_MFA, + + /* + * Authorization code flow with browser popup + */ + OAUTH_AUTHORIZATION_CODE_FLOW } diff --git a/src/main/java/net/snowflake/client/core/auth/oauth/AuthorizationCodeFlowAccessTokenProvider.java b/src/main/java/net/snowflake/client/core/auth/oauth/AuthorizationCodeFlowAccessTokenProvider.java new file mode 100644 index 000000000..9f4406e3a --- /dev/null +++ b/src/main/java/net/snowflake/client/core/auth/oauth/AuthorizationCodeFlowAccessTokenProvider.java @@ -0,0 +1,120 @@ +package net.snowflake.client.core.auth.oauth; + +import com.nimbusds.oauth2.sdk.AccessTokenResponse; +import com.nimbusds.oauth2.sdk.AuthorizationCode; +import com.nimbusds.oauth2.sdk.AuthorizationCodeGrant; +import com.nimbusds.oauth2.sdk.AuthorizationGrant; +import com.nimbusds.oauth2.sdk.AuthorizationRequest; +import com.nimbusds.oauth2.sdk.ResponseType; +import com.nimbusds.oauth2.sdk.Scope; +import com.nimbusds.oauth2.sdk.TokenErrorResponse; +import com.nimbusds.oauth2.sdk.TokenRequest; +import com.nimbusds.oauth2.sdk.TokenResponse; +import com.nimbusds.oauth2.sdk.auth.ClientAuthentication; +import com.nimbusds.oauth2.sdk.auth.ClientSecretBasic; +import com.nimbusds.oauth2.sdk.auth.Secret; +import com.nimbusds.oauth2.sdk.id.ClientID; +import com.nimbusds.oauth2.sdk.id.State; +import com.nimbusds.oauth2.sdk.token.AccessToken; +import com.sun.net.httpserver.HttpServer; +import net.snowflake.client.core.HttpUtil; +import net.snowflake.client.core.SFException; +import net.snowflake.client.core.SFLoginInput; +import net.snowflake.client.core.SnowflakeJdbcInternalApi; +import net.snowflake.client.jdbc.ErrorCode; +import org.apache.http.client.methods.HttpGet; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.concurrent.CompletableFuture; + +@SnowflakeJdbcInternalApi +public class AuthorizationCodeFlowAccessTokenProvider implements OauthAccessTokenProvider { + + private static final String REDIRECT_URI_HOST = "localhost"; + private static final int REDIRECT_URI_PORT = 8001; + private static final String REDIRECT_URI_PATH = "/oauth-redirect"; + + @Override + public String getAccessToken(SFLoginInput loginInput) throws SFException { + AuthorizationCode authorizationCode = requestAuthorizationCode(loginInput); + AccessToken accessToken = exchangeAuthorizationCodeForAccessToken(loginInput, authorizationCode); + return accessToken.getValue(); + } + + private AuthorizationCode requestAuthorizationCode(SFLoginInput loginInput) throws SFException { + try { + AuthorizationRequest request = buildAuthorizationRequest(loginInput); + + URI requestURI = request.toURI(); + HttpUtil.executeGeneralRequest(new HttpGet(requestURI), + loginInput.getLoginTimeout(), + loginInput.getAuthTimeout(), + loginInput.getSocketTimeoutInMillis(), + 0, + loginInput.getHttpClientSettingsKey()); + CompletableFuture f = getAuthorizationCodeFromRedirectURI(); + f.join(); + return new AuthorizationCode(f.get()); + } catch (Exception e) { + throw new SFException(e, ErrorCode.INTERNAL_ERROR); + } + } + + private static AccessToken exchangeAuthorizationCodeForAccessToken(SFLoginInput loginInput, AuthorizationCode authorizationCode) throws SFException { + try { + TokenRequest request = buildTokenRequest(loginInput, authorizationCode); + TokenResponse response = TokenResponse.parse(request.toHTTPRequest().send()); + if (!response.indicatesSuccess()) { + TokenErrorResponse errorResponse = response.toErrorResponse(); + errorResponse.getErrorObject(); + } + AccessTokenResponse successResponse = response.toSuccessResponse(); + return successResponse.getTokens().getAccessToken(); + } catch (Exception e) { + throw new SFException(e, ErrorCode.INTERNAL_ERROR); + } + } + + private static CompletableFuture getAuthorizationCodeFromRedirectURI() throws IOException { + CompletableFuture accessTokenFuture = new CompletableFuture<>(); + HttpServer httpServer = HttpServer.create(new InetSocketAddress(REDIRECT_URI_HOST, REDIRECT_URI_PORT), 0); + httpServer.createContext(REDIRECT_URI_PATH, exchange -> { + String authorizationCode = exchange.getRequestURI().getQuery(); + accessTokenFuture.complete(authorizationCode); + httpServer.stop(0); + }); + return accessTokenFuture; + } + + private static AuthorizationRequest buildAuthorizationRequest(SFLoginInput loginInput) throws URISyntaxException { + URI authorizeEndpoint = new URI(String.format("%s/oauth/authorize", loginInput.getServerUrl())); + ClientID clientID = new ClientID("123"); + Scope scope = new Scope("read", "write"); + URI callback = buildRedirectURI(); + State state = new State(); + return new AuthorizationRequest.Builder( + new ResponseType(ResponseType.Value.CODE), clientID) + .scope(scope) + .state(state) + .redirectionURI(callback) + .endpointURI(authorizeEndpoint) + .build(); + } + + private static URI buildRedirectURI() throws URISyntaxException { + return new URI(String.format("https://%s:%s%s", REDIRECT_URI_HOST, REDIRECT_URI_PORT, REDIRECT_URI_PATH)); + } + + private static TokenRequest buildTokenRequest(SFLoginInput loginInput, AuthorizationCode authorizationCode) throws URISyntaxException { + URI callback = buildRedirectURI(); + AuthorizationGrant codeGrant = new AuthorizationCodeGrant(authorizationCode, callback); + ClientID clientID = new ClientID("123"); + Secret clientSecret = new Secret("123"); + ClientAuthentication clientAuthentication = new ClientSecretBasic(clientID, clientSecret); + URI tokenEndpoint = new URI(String.format("%s/oauth/token", loginInput.getServerUrl())); + return new TokenRequest(tokenEndpoint, clientAuthentication, codeGrant, new Scope()); + } +} diff --git a/src/main/java/net/snowflake/client/core/auth/oauth/OauthAccessTokenProvider.java b/src/main/java/net/snowflake/client/core/auth/oauth/OauthAccessTokenProvider.java new file mode 100644 index 000000000..05e9dacc5 --- /dev/null +++ b/src/main/java/net/snowflake/client/core/auth/oauth/OauthAccessTokenProvider.java @@ -0,0 +1,9 @@ +package net.snowflake.client.core.auth.oauth; + +import net.snowflake.client.core.SFException; +import net.snowflake.client.core.SFLoginInput; + +public interface OauthAccessTokenProvider { + + String getAccessToken(SFLoginInput loginInput) throws SFException; +} diff --git a/thin_public_pom.xml b/thin_public_pom.xml index 09c6bf079..eeb42d0f0 100644 --- a/thin_public_pom.xml +++ b/thin_public_pom.xml @@ -194,6 +194,11 @@ nimbus-jose-jwt ${nimbusds.version} + + com.nimbusds + oauth2-oidc-sdk + ${nimbusds.oauth2.version} + com.yammer.metrics metrics-core