Skip to content

Commit

Permalink
Implement full flow
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-dheyman committed Dec 2, 2024
1 parent 195ff36 commit a3dad01
Show file tree
Hide file tree
Showing 8 changed files with 206 additions and 48 deletions.
32 changes: 32 additions & 0 deletions src/main/java/net/snowflake/client/core/SFLoginInput.java
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ public class SFLoginInput {
private boolean enableClientStoreTemporaryCredential;
private boolean enableClientRequestMfaToken;

//OAuth
private int redirectUriPort = -1;
private String clientId;
private String clientSecret;

private Duration browserResponseTimeout;

// Additional headers to add for Snowsight.
Expand Down Expand Up @@ -417,6 +422,33 @@ SFLoginInput setDisableSamlURLCheck(boolean disableSamlURLCheck) {
return this;
}

public int getRedirectUriPort() {
return redirectUriPort;
}

public SFLoginInput setRedirectUriPort(int redirectUriPort) {
this.redirectUriPort = redirectUriPort;
return this;
}

public String getClientId() {
return clientId;
}

public SFLoginInput setClientId(String clientId) {
this.clientId = clientId;
return this;
}

public String getClientSecret() {
return clientSecret;
}

public SFLoginInput setClientSecret(String clientSecret) {
this.clientSecret = clientSecret;
return this;
}

Map<String, String> getAdditionalHttpHeadersForSnowsight() {
return additionalHttpHeadersForSnowsight;
}
Expand Down
6 changes: 6 additions & 0 deletions src/main/java/net/snowflake/client/core/SFSession.java
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,8 @@ public synchronized void open() throws SFException, SnowflakeSQLException {
.setSessionParameters(sessionParametersMap)
.setPrivateKey((PrivateKey) connectionPropertiesMap.get(SFSessionProperty.PRIVATE_KEY))
.setPrivateKeyFile((String) connectionPropertiesMap.get(SFSessionProperty.PRIVATE_KEY_FILE))
.setClientId((String) connectionPropertiesMap.get(SFSessionProperty.CLIENT_ID))
.setClientSecret((String) connectionPropertiesMap.get(SFSessionProperty.CLIENT_SECRET))
.setPrivateKeyBase64(
(String) connectionPropertiesMap.get(SFSessionProperty.PRIVATE_KEY_BASE64))
.setPrivateKeyPwd(
Expand All @@ -696,6 +698,10 @@ public synchronized void open() throws SFException, SnowflakeSQLException {
.setEnableClientRequestMfaToken(enableClientRequestMfaToken)
.setBrowserResponseTimeout(browserResponseTimeout);

if (connectionPropertiesMap.containsKey(SFSessionProperty.OAUTH_REDIRECT_URI_PORT)) {
loginInput.setRedirectUriPort((Integer) connectionPropertiesMap.get(SFSessionProperty.OAUTH_REDIRECT_URI_PORT));
}

logger.info(
"Connecting to {} Snowflake domain",
loginInput.getHostFromServerUrl().toLowerCase().endsWith(".cn") ? "CHINA" : "GLOBAL");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ public enum SFSessionProperty {
AUTHENTICATOR("authenticator", false, String.class),
OKTA_USERNAME("oktausername", false, String.class),
PRIVATE_KEY("privateKey", false, PrivateKey.class),
OAUTH_REDIRECT_URI_PORT("oauthRedirectUriPort", false, Integer.class),
CLIENT_ID("clientID", false, String.class),
CLIENT_SECRET("clientSecret", false, String.class),
WAREHOUSE("warehouse", false, String.class),
LOGIN_TIMEOUT("loginTimeout", false, Integer.class),
NETWORK_TIMEOUT("networkTimeout", false, Integer.class),
Expand Down
7 changes: 6 additions & 1 deletion src/main/java/net/snowflake/client/core/SessionUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,11 @@ private static AuthenticatorType getAuthenticator(SFLoginInput loginInput) {
.equalsIgnoreCase(AuthenticatorType.EXTERNALBROWSER.name())) {
// SAML 2.0 compliant service/application
return AuthenticatorType.EXTERNALBROWSER;
} else if (loginInput.getAuthenticator().equalsIgnoreCase(AuthenticatorType.OAUTH_AUTHORIZATION_CODE_FLOW.name())) {
// OAuth authorization code flow authentication
return AuthenticatorType.OAUTH_AUTHORIZATION_CODE_FLOW;
} else if (loginInput.getAuthenticator().equalsIgnoreCase(AuthenticatorType.OAUTH.name())) {
// OAuth Authentication
// OAuth access code Authentication
return AuthenticatorType.OAUTH;
} else if (loginInput
.getAuthenticator()
Expand Down Expand Up @@ -268,6 +271,8 @@ static SFLoginOutput openSession(
loginInput.getLoginTimeout() >= 0, "negative login timeout for opening session");

if (getAuthenticator(loginInput).equals(AuthenticatorType.OAUTH_AUTHORIZATION_CODE_FLOW)) {
AssertUtil.assertTrue(loginInput.getClientId() != null, "passing clientId is required for OAUTH_AUTHORIZATION_CODE_FLOW authentication");
AssertUtil.assertTrue(loginInput.getClientSecret() != null, "passing clientSecret is required for OAUTH_AUTHORIZATION_CODE_FLOW authentication");
OauthAccessTokenProvider accessTokenProvider = new AuthorizationCodeFlowAccessTokenProvider();
String oauthAccessToken = accessTokenProvider.getAccessToken(loginInput);
loginInput.setAuthenticator(AuthenticatorType.OAUTH.name());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ public interface AuthExternalBrowserHandlers {
void output(String msg);
}

static class DefaultAuthExternalBrowserHandlers implements AuthExternalBrowserHandlers {
public static class DefaultAuthExternalBrowserHandlers implements AuthExternalBrowserHandlers {
@Override
public HttpPost build(URI uri) {
return new HttpPost(uri);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,97 +1,123 @@
package net.snowflake.client.core.auth.oauth;

import com.nimbusds.oauth2.sdk.AccessTokenResponse;
import com.amazonaws.util.StringUtils;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.nimbusds.oauth2.sdk.AuthorizationCode;
import com.nimbusds.oauth2.sdk.AuthorizationCodeGrant;
import com.nimbusds.oauth2.sdk.AuthorizationGrant;
import com.nimbusds.oauth2.sdk.AuthorizationRequest;
import com.nimbusds.oauth2.sdk.ResponseType;
import com.nimbusds.oauth2.sdk.Scope;
import com.nimbusds.oauth2.sdk.TokenErrorResponse;
import com.nimbusds.oauth2.sdk.TokenRequest;
import com.nimbusds.oauth2.sdk.TokenResponse;
import com.nimbusds.oauth2.sdk.auth.ClientAuthentication;
import com.nimbusds.oauth2.sdk.auth.ClientSecretBasic;
import com.nimbusds.oauth2.sdk.auth.Secret;
import com.nimbusds.oauth2.sdk.http.HTTPRequest;
import com.nimbusds.oauth2.sdk.id.ClientID;
import com.nimbusds.oauth2.sdk.id.State;
import com.nimbusds.oauth2.sdk.token.AccessToken;
import com.sun.net.httpserver.HttpServer;
import net.snowflake.client.core.HttpUtil;
import net.snowflake.client.core.SFException;
import net.snowflake.client.core.SFLoginInput;
import net.snowflake.client.core.SnowflakeJdbcInternalApi;
import net.snowflake.client.jdbc.ErrorCode;
import org.apache.http.client.methods.HttpGet;
import net.snowflake.client.log.SFLogger;
import net.snowflake.client.log.SFLoggerFactory;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.client.methods.HttpRequestBase;
import org.apache.http.entity.StringEntity;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

import static net.snowflake.client.core.SessionUtilExternalBrowser.DefaultAuthExternalBrowserHandlers;

@SnowflakeJdbcInternalApi
public class AuthorizationCodeFlowAccessTokenProvider implements OauthAccessTokenProvider {

private static final SFLogger logger = SFLoggerFactory.getLogger(AuthorizationCodeFlowAccessTokenProvider.class);

private static final String AUTHORIZE_ENDPOINT = "/oauth/authorize";
private static final String TOKEN_REQUEST_ENDPOINT = "/oauth/token-request";

private static final String REDIRECT_URI_HOST = "localhost";
private static final int REDIRECT_URI_PORT = 8001;
private static final String REDIRECT_URI_PATH = "/oauth-redirect";
private static final int DEFAULT_REDIRECT_URI_PORT = 8001;
private static final String REDIRECT_URI_ENDPOINT = "/snowflake/oauth-redirect";
public static final String SESSION_ROLE_SCOPE = "session:role";

public static int AUTHORIZE_REDIRECT_TIMEOUT_MINUTES = 2;

private final DefaultAuthExternalBrowserHandlers browserUtil = new DefaultAuthExternalBrowserHandlers();
private final ObjectMapper objectMapper = new ObjectMapper();

@Override
public String getAccessToken(SFLoginInput loginInput) throws SFException {
AuthorizationCode authorizationCode = requestAuthorizationCode(loginInput);
AccessToken accessToken = exchangeAuthorizationCodeForAccessToken(loginInput, authorizationCode);
return accessToken.getValue();
return exchangeAuthorizationCodeForAccessToken(loginInput, authorizationCode);
}

private AuthorizationCode requestAuthorizationCode(SFLoginInput loginInput) throws SFException {
try {
AuthorizationRequest request = buildAuthorizationRequest(loginInput);
URI requestURI = request.toURI();
HttpUtil.executeGeneralRequest(new HttpGet(requestURI),
loginInput.getLoginTimeout(),
loginInput.getAuthTimeout(),
loginInput.getSocketTimeoutInMillis(),
0,
loginInput.getHttpClientSettingsKey());
String code = getAuthorizationCodeFromRedirectURI().join();
URI authorizeRequestURI = request.toURI();
CompletableFuture<String> codeFuture = setupRedirectURIServerForAuthorizationCode(loginInput.getRedirectUriPort());
letUserAuthorizeViaBrowser(authorizeRequestURI);
String code = codeFuture.get(AUTHORIZE_REDIRECT_TIMEOUT_MINUTES, TimeUnit.MINUTES);
return new AuthorizationCode(code);
} catch (Exception e) {
throw new SFException(e, ErrorCode.INTERNAL_ERROR);
if (e instanceof TimeoutException) {
logger.error("Authorization request timed out. Did not receive authorization code back to the redirect URI");
}
throw new RuntimeException(e.getMessage(), e);
}
}

private static AccessToken exchangeAuthorizationCodeForAccessToken(SFLoginInput loginInput, AuthorizationCode authorizationCode) throws SFException {
private String exchangeAuthorizationCodeForAccessToken(SFLoginInput loginInput, AuthorizationCode authorizationCode) throws SFException {
try {
TokenRequest request = buildTokenRequest(loginInput, authorizationCode);
TokenResponse response = TokenResponse.parse(request.toHTTPRequest().send());
if (!response.indicatesSuccess()) {
TokenErrorResponse errorResponse = response.toErrorResponse();
errorResponse.getErrorObject();
}
AccessTokenResponse successResponse = response.toSuccessResponse();
return successResponse.getTokens().getAccessToken();
String tokenResponse = HttpUtil.executeGeneralRequest(
convertTokenRequest(request.toHTTPRequest()),
loginInput.getLoginTimeout(),
loginInput.getAuthTimeout(),
loginInput.getSocketTimeoutInMillis(),
0,
loginInput.getHttpClientSettingsKey());
TokenResponseDTO tokenResponseDTO = objectMapper.readValue(tokenResponse, TokenResponseDTO.class);
return tokenResponseDTO.getAccessToken();
} catch (Exception e) {
throw new SFException(e, ErrorCode.INTERNAL_ERROR);
throw new RuntimeException(e);
}
}

private static CompletableFuture<String> getAuthorizationCodeFromRedirectURI() throws IOException {
private void letUserAuthorizeViaBrowser(URI authorizeRequestURI) throws SFException {
browserUtil.openBrowser(authorizeRequestURI.toString());
}

private static CompletableFuture<String> setupRedirectURIServerForAuthorizationCode(int redirectUriPort) throws IOException {
CompletableFuture<String> accessTokenFuture = new CompletableFuture<>();
HttpServer httpServer = HttpServer.create(new InetSocketAddress(REDIRECT_URI_HOST, REDIRECT_URI_PORT), 0);
httpServer.createContext(REDIRECT_URI_PATH, exchange -> {
String authorizationCode = exchange.getRequestURI().getQuery();
accessTokenFuture.complete(authorizationCode);
httpServer.stop(0);
int redirectPort = (redirectUriPort != -1) ? redirectUriPort : DEFAULT_REDIRECT_URI_PORT;
HttpServer httpServer = HttpServer.create(new InetSocketAddress(REDIRECT_URI_HOST, redirectPort), 0);
httpServer.createContext(REDIRECT_URI_ENDPOINT, exchange -> {
String authorizationCode = extractAuthorizationCodeFromQueryParameters(exchange.getRequestURI().getQuery());
if (!StringUtils.isNullOrEmpty(authorizationCode)) {
accessTokenFuture.complete(authorizationCode);
httpServer.stop(0);
}
});
httpServer.start();
return accessTokenFuture;
}

private static AuthorizationRequest buildAuthorizationRequest(SFLoginInput loginInput) throws URISyntaxException {
URI authorizeEndpoint = new URI(String.format("%s/oauth/authorize", loginInput.getServerUrl()));
ClientID clientID = new ClientID("123");
Scope scope = new Scope(String.format("session:role:%s", loginInput.getRole()));
URI callback = buildRedirectURI();
URI authorizeEndpoint = new URI(loginInput.getServerUrl() + AUTHORIZE_ENDPOINT);
ClientID clientID = new ClientID(loginInput.getClientId());
Scope scope = new Scope(String.format("%s:%s", SESSION_ROLE_SCOPE, loginInput.getRole()));
URI callback = buildRedirectURI(loginInput.getRedirectUriPort());
State state = new State(256);
return new AuthorizationRequest.Builder(
new ResponseType(ResponseType.Value.CODE), clientID)
Expand All @@ -102,16 +128,30 @@ private static AuthorizationRequest buildAuthorizationRequest(SFLoginInput login
.build();
}

private static URI buildRedirectURI() throws URISyntaxException {
return new URI(String.format("https://%s:%s%s", REDIRECT_URI_HOST, REDIRECT_URI_PORT, REDIRECT_URI_PATH));
}

private static TokenRequest buildTokenRequest(SFLoginInput loginInput, AuthorizationCode authorizationCode) throws URISyntaxException {
URI callback = buildRedirectURI();
URI callback = buildRedirectURI(loginInput.getRedirectUriPort());
AuthorizationGrant codeGrant = new AuthorizationCodeGrant(authorizationCode, callback);
ClientAuthentication clientAuthentication = new ClientSecretBasic(new ClientID("123"), new Secret("123"));
URI tokenEndpoint = new URI(String.format("%s/oauth/token-request", loginInput.getServerUrl()));
Scope scope = new Scope("session:role", loginInput.getRole());
ClientAuthentication clientAuthentication = new ClientSecretBasic(new ClientID(loginInput.getClientId()), new Secret(loginInput.getClientSecret()));
URI tokenEndpoint = new URI(String.format(loginInput.getServerUrl() + TOKEN_REQUEST_ENDPOINT));
Scope scope = new Scope(SESSION_ROLE_SCOPE, loginInput.getRole());
return new TokenRequest(tokenEndpoint, clientAuthentication, codeGrant, scope);
}

private static URI buildRedirectURI(int redirectUriPort) throws URISyntaxException {
redirectUriPort = (redirectUriPort != -1) ? redirectUriPort : DEFAULT_REDIRECT_URI_PORT;
return new URI(String.format("http://%s:%s%s", REDIRECT_URI_HOST, redirectUriPort, REDIRECT_URI_ENDPOINT));
}

private static String extractAuthorizationCodeFromQueryParameters(String queryParameters) {
String prefix = "code=";
String codeSuffix = queryParameters.substring(queryParameters.indexOf(prefix) + prefix.length());
return codeSuffix.substring(0, codeSuffix.indexOf("&"));
}

private static HttpRequestBase convertTokenRequest(HTTPRequest nimbusRequest) {
HttpPost request = new HttpPost(nimbusRequest.getURI());
request.setEntity(new StringEntity(nimbusRequest.getBody(), StandardCharsets.UTF_8));
nimbusRequest.getHeaderMap().forEach((key, values) -> request.addHeader(key, values.get(0)));
return request;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package net.snowflake.client.core.auth.oauth;


import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;

class TokenResponseDTO {

private final String accessToken;
private final String refreshToken;
private final String tokenType;
private final String scope;
private final String username;
private final boolean idpInitiated;
private final long expiresIn;
private final long refreshTokenExpiresIn;

@JsonCreator(mode = JsonCreator.Mode.PROPERTIES)
public TokenResponseDTO(@JsonProperty("access_token") String accessToken,
@JsonProperty("refresh_token") String refreshToken,
@JsonProperty("token_type") String tokenType,
@JsonProperty("scope") String scope,
@JsonProperty("username") String username,
@JsonProperty("idp_initiated") boolean idpInitiated,
@JsonProperty("expires_in") long expiresIn,
@JsonProperty("refresh_token_expires_in") long refreshTokenExpiresIn) {
this.accessToken = accessToken;
this.tokenType = tokenType;
this.refreshToken = refreshToken;
this.scope = scope;
this.username = username;
this.idpInitiated = idpInitiated;
this.expiresIn = expiresIn;
this.refreshTokenExpiresIn = refreshTokenExpiresIn;
}

public String getAccessToken() {
return accessToken;
}

public String getTokenType() {
return tokenType;
}

public String getRefreshToken() {
return refreshToken;
}

public String getScope() {
return scope;
}

public long getExpiresIn() {
return expiresIn;
}

public String getUsername() {
return username;
}

public long getRefreshTokenExpiresIn() {
return refreshTokenExpiresIn;
}

public boolean isIdpInitiated() {
return idpInitiated;
}
}
3 changes: 3 additions & 0 deletions src/test/java/net/snowflake/client/AbstractDriverIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import static org.hamcrest.MatcherAssert.assertThat;

import com.google.common.base.Strings;
import net.snowflake.client.core.auth.AuthenticatorType;

import java.net.URISyntaxException;
import java.net.URL;
import java.nio.file.Paths;
Expand Down Expand Up @@ -323,6 +325,7 @@ public static Connection getConnection(

properties.put("internal", Boolean.TRUE.toString()); // TODO: do we need this?
properties.put("insecureMode", false); // use OCSP for all tests.
properties.put("authenticator", AuthenticatorType.OAUTH_AUTHORIZATION_CODE_FLOW.name());

if (injectSocketTimeout > 0) {
properties.put("injectSocketTimeout", String.valueOf(injectSocketTimeout));
Expand Down

0 comments on commit a3dad01

Please sign in to comment.