Skip to content

Commit

Permalink
Add wiremock test
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-dheyman committed Dec 2, 2024
1 parent a3dad01 commit 0982aaf
Show file tree
Hide file tree
Showing 9 changed files with 368 additions and 187 deletions.
10 changes: 5 additions & 5 deletions src/main/java/net/snowflake/client/core/SFLoginInput.java
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public class SFLoginInput {
private boolean enableClientStoreTemporaryCredential;
private boolean enableClientRequestMfaToken;

//OAuth
// OAuth
private int redirectUriPort = -1;
private String clientId;
private String clientSecret;
Expand All @@ -64,7 +64,7 @@ public class SFLoginInput {
// Additional headers to add for Snowsight.
Map<String, String> additionalHttpHeadersForSnowsight;

SFLoginInput() {}
public SFLoginInput() {}

Duration getBrowserResponseTimeout() {
return browserResponseTimeout;
Expand All @@ -79,7 +79,7 @@ public String getServerUrl() {
return serverUrl;
}

SFLoginInput setServerUrl(String serverUrl) {
public SFLoginInput setServerUrl(String serverUrl) {
this.serverUrl = serverUrl;
return this;
}
Expand Down Expand Up @@ -247,7 +247,7 @@ public int getSocketTimeoutInMillis() {
return (int) socketTimeout.toMillis();
}

SFLoginInput setSocketTimeout(Duration socketTimeout) {
public SFLoginInput setSocketTimeout(Duration socketTimeout) {
this.socketTimeout = socketTimeout;
return this;
}
Expand Down Expand Up @@ -397,7 +397,7 @@ public HttpClientSettingsKey getHttpClientSettingsKey() {
return httpClientKey;
}

SFLoginInput setHttpClientSettingsKey(HttpClientSettingsKey key) {
public SFLoginInput setHttpClientSettingsKey(HttpClientSettingsKey key) {
this.httpClientKey = key;
return this;
}
Expand Down
3 changes: 2 additions & 1 deletion src/main/java/net/snowflake/client/core/SFSession.java
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,8 @@ public synchronized void open() throws SFException, SnowflakeSQLException {
.setBrowserResponseTimeout(browserResponseTimeout);

if (connectionPropertiesMap.containsKey(SFSessionProperty.OAUTH_REDIRECT_URI_PORT)) {
loginInput.setRedirectUriPort((Integer) connectionPropertiesMap.get(SFSessionProperty.OAUTH_REDIRECT_URI_PORT));
loginInput.setRedirectUriPort(
(Integer) connectionPropertiesMap.get(SFSessionProperty.OAUTH_REDIRECT_URI_PORT));
}

logger.info(
Expand Down
16 changes: 12 additions & 4 deletions src/main/java/net/snowflake/client/core/SessionUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,9 @@ private static AuthenticatorType getAuthenticator(SFLoginInput loginInput) {
.equalsIgnoreCase(AuthenticatorType.EXTERNALBROWSER.name())) {
// SAML 2.0 compliant service/application
return AuthenticatorType.EXTERNALBROWSER;
} else if (loginInput.getAuthenticator().equalsIgnoreCase(AuthenticatorType.OAUTH_AUTHORIZATION_CODE_FLOW.name())) {
} else if (loginInput
.getAuthenticator()
.equalsIgnoreCase(AuthenticatorType.OAUTH_AUTHORIZATION_CODE_FLOW.name())) {
// OAuth authorization code flow authentication
return AuthenticatorType.OAUTH_AUTHORIZATION_CODE_FLOW;
} else if (loginInput.getAuthenticator().equalsIgnoreCase(AuthenticatorType.OAUTH.name())) {
Expand Down Expand Up @@ -271,9 +273,15 @@ static SFLoginOutput openSession(
loginInput.getLoginTimeout() >= 0, "negative login timeout for opening session");

if (getAuthenticator(loginInput).equals(AuthenticatorType.OAUTH_AUTHORIZATION_CODE_FLOW)) {
AssertUtil.assertTrue(loginInput.getClientId() != null, "passing clientId is required for OAUTH_AUTHORIZATION_CODE_FLOW authentication");
AssertUtil.assertTrue(loginInput.getClientSecret() != null, "passing clientSecret is required for OAUTH_AUTHORIZATION_CODE_FLOW authentication");
OauthAccessTokenProvider accessTokenProvider = new AuthorizationCodeFlowAccessTokenProvider();
AssertUtil.assertTrue(
loginInput.getClientId() != null,
"passing clientId is required for OAUTH_AUTHORIZATION_CODE_FLOW authentication");
AssertUtil.assertTrue(
loginInput.getClientSecret() != null,
"passing clientSecret is required for OAUTH_AUTHORIZATION_CODE_FLOW authentication");
OauthAccessTokenProvider accessTokenProvider =
new AuthorizationCodeFlowAccessTokenProvider(
new SessionUtilExternalBrowser.DefaultAuthExternalBrowserHandlers());
String oauthAccessToken = accessTokenProvider.getAccessToken(loginInput);
loginInput.setAuthenticator(AuthenticatorType.OAUTH.name());
loginInput.setToken(oauthAccessToken);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package net.snowflake.client.core.auth.oauth;

import static net.snowflake.client.core.SessionUtilExternalBrowser.*;

import com.amazonaws.util.StringUtils;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.nimbusds.oauth2.sdk.AuthorizationCode;
Expand All @@ -15,7 +17,17 @@
import com.nimbusds.oauth2.sdk.http.HTTPRequest;
import com.nimbusds.oauth2.sdk.id.ClientID;
import com.nimbusds.oauth2.sdk.id.State;
import com.nimbusds.oauth2.sdk.pkce.CodeChallengeMethod;
import com.nimbusds.oauth2.sdk.pkce.CodeVerifier;
import com.sun.net.httpserver.HttpServer;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import net.snowflake.client.core.HttpUtil;
import net.snowflake.client.core.SFException;
import net.snowflake.client.core.SFLoginInput;
Expand All @@ -26,132 +38,154 @@
import org.apache.http.client.methods.HttpRequestBase;
import org.apache.http.entity.StringEntity;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

import static net.snowflake.client.core.SessionUtilExternalBrowser.DefaultAuthExternalBrowserHandlers;

@SnowflakeJdbcInternalApi
public class AuthorizationCodeFlowAccessTokenProvider implements OauthAccessTokenProvider {

private static final SFLogger logger = SFLoggerFactory.getLogger(AuthorizationCodeFlowAccessTokenProvider.class);

private static final String AUTHORIZE_ENDPOINT = "/oauth/authorize";
private static final String TOKEN_REQUEST_ENDPOINT = "/oauth/token-request";

private static final String REDIRECT_URI_HOST = "localhost";
private static final int DEFAULT_REDIRECT_URI_PORT = 8001;
private static final String REDIRECT_URI_ENDPOINT = "/snowflake/oauth-redirect";
public static final String SESSION_ROLE_SCOPE = "session:role";

public static int AUTHORIZE_REDIRECT_TIMEOUT_MINUTES = 2;

private final DefaultAuthExternalBrowserHandlers browserUtil = new DefaultAuthExternalBrowserHandlers();
private final ObjectMapper objectMapper = new ObjectMapper();

@Override
public String getAccessToken(SFLoginInput loginInput) throws SFException {
AuthorizationCode authorizationCode = requestAuthorizationCode(loginInput);
return exchangeAuthorizationCodeForAccessToken(loginInput, authorizationCode);
}

private AuthorizationCode requestAuthorizationCode(SFLoginInput loginInput) throws SFException {
try {
AuthorizationRequest request = buildAuthorizationRequest(loginInput);
URI authorizeRequestURI = request.toURI();
CompletableFuture<String> codeFuture = setupRedirectURIServerForAuthorizationCode(loginInput.getRedirectUriPort());
letUserAuthorizeViaBrowser(authorizeRequestURI);
String code = codeFuture.get(AUTHORIZE_REDIRECT_TIMEOUT_MINUTES, TimeUnit.MINUTES);
return new AuthorizationCode(code);
} catch (Exception e) {
if (e instanceof TimeoutException) {
logger.error("Authorization request timed out. Did not receive authorization code back to the redirect URI");
}
throw new RuntimeException(e.getMessage(), e);
}
private static final SFLogger logger =
SFLoggerFactory.getLogger(AuthorizationCodeFlowAccessTokenProvider.class);

private static final String SNOWFLAKE_AUTHORIZE_ENDPOINT = "/oauth/authorize";
private static final String SNOWFLAKE_TOKEN_REQUEST_ENDPOINT = "/oauth/token-request";

private static final String REDIRECT_URI_HOST = "localhost";
private static final int DEFAULT_REDIRECT_URI_PORT = 8001;
private static final String REDIRECT_URI_ENDPOINT = "/snowflake/oauth-redirect";
public static final String SESSION_ROLE_SCOPE = "session:role";

public static int AUTHORIZE_REDIRECT_TIMEOUT_MINUTES = 2;

private final AuthExternalBrowserHandlers browserHandler;
private final ObjectMapper objectMapper = new ObjectMapper();

public AuthorizationCodeFlowAccessTokenProvider(AuthExternalBrowserHandlers browserHandler) {
this.browserHandler = browserHandler;
}

@Override
public String getAccessToken(SFLoginInput loginInput) throws SFException {
CodeVerifier pkceVerifier = new CodeVerifier();
AuthorizationCode authorizationCode = requestAuthorizationCode(loginInput, pkceVerifier);
return exchangeAuthorizationCodeForAccessToken(loginInput, authorizationCode, pkceVerifier);
}

private AuthorizationCode requestAuthorizationCode(
SFLoginInput loginInput, CodeVerifier pkceVerifier) throws SFException {
try {
AuthorizationRequest request = buildAuthorizationRequest(loginInput, pkceVerifier);
URI authorizeRequestURI = request.toURI();
CompletableFuture<String> codeFuture =
setupRedirectURIServerForAuthorizationCode(loginInput.getRedirectUriPort());
letUserAuthorizeViaBrowser(authorizeRequestURI);
String code = codeFuture.get(AUTHORIZE_REDIRECT_TIMEOUT_MINUTES, TimeUnit.MINUTES);
return new AuthorizationCode(code);
} catch (Exception e) {
if (e instanceof TimeoutException) {
logger.error(
"Authorization request timed out. Did not receive authorization code back to the redirect URI");
}
throw new RuntimeException(e);
}

private String exchangeAuthorizationCodeForAccessToken(SFLoginInput loginInput, AuthorizationCode authorizationCode) throws SFException {
try {
TokenRequest request = buildTokenRequest(loginInput, authorizationCode);
String tokenResponse = HttpUtil.executeGeneralRequest(
convertTokenRequest(request.toHTTPRequest()),
loginInput.getLoginTimeout(),
loginInput.getAuthTimeout(),
loginInput.getSocketTimeoutInMillis(),
0,
loginInput.getHttpClientSettingsKey());
TokenResponseDTO tokenResponseDTO = objectMapper.readValue(tokenResponse, TokenResponseDTO.class);
return tokenResponseDTO.getAccessToken();
} catch (Exception e) {
throw new RuntimeException(e);
}
}

private String exchangeAuthorizationCodeForAccessToken(
SFLoginInput loginInput, AuthorizationCode authorizationCode, CodeVerifier pkceVerifier)
throws SFException {
try {
TokenRequest request = buildTokenRequest(loginInput, authorizationCode, pkceVerifier);
String tokenResponse =
HttpUtil.executeGeneralRequest(
convertTokenRequest(request.toHTTPRequest()),
loginInput.getLoginTimeout(),
loginInput.getAuthTimeout(),
loginInput.getSocketTimeoutInMillis(),
0,
loginInput.getHttpClientSettingsKey());
TokenResponseDTO tokenResponseDTO =
objectMapper.readValue(tokenResponse, TokenResponseDTO.class);
return tokenResponseDTO.getAccessToken();
} catch (Exception e) {
throw new RuntimeException(e);
}

private void letUserAuthorizeViaBrowser(URI authorizeRequestURI) throws SFException {
browserUtil.openBrowser(authorizeRequestURI.toString());
}

private static CompletableFuture<String> setupRedirectURIServerForAuthorizationCode(int redirectUriPort) throws IOException {
CompletableFuture<String> accessTokenFuture = new CompletableFuture<>();
int redirectPort = (redirectUriPort != -1) ? redirectUriPort : DEFAULT_REDIRECT_URI_PORT;
HttpServer httpServer = HttpServer.create(new InetSocketAddress(REDIRECT_URI_HOST, redirectPort), 0);
httpServer.createContext(REDIRECT_URI_ENDPOINT, exchange -> {
String authorizationCode = extractAuthorizationCodeFromQueryParameters(exchange.getRequestURI().getQuery());
if (!StringUtils.isNullOrEmpty(authorizationCode)) {
accessTokenFuture.complete(authorizationCode);
httpServer.stop(0);
}
}

private void letUserAuthorizeViaBrowser(URI authorizeRequestURI) throws SFException {
browserHandler.openBrowser(authorizeRequestURI.toString());
}

private static CompletableFuture<String> setupRedirectURIServerForAuthorizationCode(
int redirectUriPort) throws IOException {
CompletableFuture<String> accessTokenFuture = new CompletableFuture<>();
int redirectPort = (redirectUriPort != -1) ? redirectUriPort : DEFAULT_REDIRECT_URI_PORT;
HttpServer httpServer =
HttpServer.create(new InetSocketAddress(REDIRECT_URI_HOST, redirectPort), 0);
httpServer.createContext(
REDIRECT_URI_ENDPOINT,
exchange -> {
exchange.sendResponseHeaders(200, 0);
exchange.getResponseBody().close();
String authorizationCode =
extractAuthorizationCodeFromQueryParameters(exchange.getRequestURI().getQuery());
if (!StringUtils.isNullOrEmpty(authorizationCode)) {
accessTokenFuture.complete(authorizationCode);
httpServer.stop(0);
}
});
httpServer.start();
return accessTokenFuture;
}

private static AuthorizationRequest buildAuthorizationRequest(SFLoginInput loginInput) throws URISyntaxException {
URI authorizeEndpoint = new URI(loginInput.getServerUrl() + AUTHORIZE_ENDPOINT);
ClientID clientID = new ClientID(loginInput.getClientId());
Scope scope = new Scope(String.format("%s:%s", SESSION_ROLE_SCOPE, loginInput.getRole()));
URI callback = buildRedirectURI(loginInput.getRedirectUriPort());
State state = new State(256);
return new AuthorizationRequest.Builder(
new ResponseType(ResponseType.Value.CODE), clientID)
.scope(scope)
.state(state)
.redirectionURI(callback)
.endpointURI(authorizeEndpoint)
.build();
}

private static TokenRequest buildTokenRequest(SFLoginInput loginInput, AuthorizationCode authorizationCode) throws URISyntaxException {
URI callback = buildRedirectURI(loginInput.getRedirectUriPort());
AuthorizationGrant codeGrant = new AuthorizationCodeGrant(authorizationCode, callback);
ClientAuthentication clientAuthentication = new ClientSecretBasic(new ClientID(loginInput.getClientId()), new Secret(loginInput.getClientSecret()));
URI tokenEndpoint = new URI(String.format(loginInput.getServerUrl() + TOKEN_REQUEST_ENDPOINT));
Scope scope = new Scope(SESSION_ROLE_SCOPE, loginInput.getRole());
return new TokenRequest(tokenEndpoint, clientAuthentication, codeGrant, scope);
}

private static URI buildRedirectURI(int redirectUriPort) throws URISyntaxException {
redirectUriPort = (redirectUriPort != -1) ? redirectUriPort : DEFAULT_REDIRECT_URI_PORT;
return new URI(String.format("http://%s:%s%s", REDIRECT_URI_HOST, redirectUriPort, REDIRECT_URI_ENDPOINT));
}

private static String extractAuthorizationCodeFromQueryParameters(String queryParameters) {
String prefix = "code=";
String codeSuffix = queryParameters.substring(queryParameters.indexOf(prefix) + prefix.length());
return codeSuffix.substring(0, codeSuffix.indexOf("&"));
}

private static HttpRequestBase convertTokenRequest(HTTPRequest nimbusRequest) {
HttpPost request = new HttpPost(nimbusRequest.getURI());
request.setEntity(new StringEntity(nimbusRequest.getBody(), StandardCharsets.UTF_8));
nimbusRequest.getHeaderMap().forEach((key, values) -> request.addHeader(key, values.get(0)));
return request;
httpServer.start();
return accessTokenFuture;
}

private static AuthorizationRequest buildAuthorizationRequest(
SFLoginInput loginInput, CodeVerifier pkceVerifier) throws URISyntaxException {
URI authorizeEndpoint = new URI(loginInput.getServerUrl() + SNOWFLAKE_AUTHORIZE_ENDPOINT);
ClientID clientID = new ClientID(loginInput.getClientId());
Scope scope = new Scope(String.format("%s:%s", SESSION_ROLE_SCOPE, loginInput.getRole()));
URI callback = buildRedirectURI(loginInput.getRedirectUriPort());
State state = new State(256);
return new AuthorizationRequest.Builder(new ResponseType(ResponseType.Value.CODE), clientID)
.scope(scope)
.state(state)
.redirectionURI(callback)
.codeChallenge(pkceVerifier, CodeChallengeMethod.S256)
.endpointURI(authorizeEndpoint)
.build();
}

private static TokenRequest buildTokenRequest(
SFLoginInput loginInput, AuthorizationCode authorizationCode, CodeVerifier pkceVerifier)
throws URISyntaxException {
URI redirectURI = buildRedirectURI(loginInput.getRedirectUriPort());
AuthorizationGrant codeGrant =
new AuthorizationCodeGrant(authorizationCode, redirectURI, pkceVerifier);
ClientAuthentication clientAuthentication =
new ClientSecretBasic(
new ClientID(loginInput.getClientId()), new Secret(loginInput.getClientSecret()));
URI tokenEndpoint =
new URI(String.format(loginInput.getServerUrl() + SNOWFLAKE_TOKEN_REQUEST_ENDPOINT));
Scope scope = new Scope(SESSION_ROLE_SCOPE, loginInput.getRole());
return new TokenRequest(tokenEndpoint, clientAuthentication, codeGrant, scope);
}

private static URI buildRedirectURI(int redirectUriPort) throws URISyntaxException {
redirectUriPort = (redirectUriPort != -1) ? redirectUriPort : DEFAULT_REDIRECT_URI_PORT;
return new URI(
String.format("http://%s:%s%s", REDIRECT_URI_HOST, redirectUriPort, REDIRECT_URI_ENDPOINT));
}

private static String extractAuthorizationCodeFromQueryParameters(String queryParameters) {
String prefix = "code=";
String codeSuffix =
queryParameters.substring(queryParameters.indexOf(prefix) + prefix.length());
if (codeSuffix.contains("&")) {
return codeSuffix.substring(0, codeSuffix.indexOf("&"));
} else {
return codeSuffix;
}
}

private static HttpRequestBase convertTokenRequest(HTTPRequest nimbusRequest) {
HttpPost request = new HttpPost(nimbusRequest.getURI());
request.setEntity(new StringEntity(nimbusRequest.getBody(), StandardCharsets.UTF_8));
nimbusRequest.getHeaderMap().forEach((key, values) -> request.addHeader(key, values.get(0)));
return request;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@
@SnowflakeJdbcInternalApi
public interface OauthAccessTokenProvider {

String getAccessToken(SFLoginInput loginInput) throws SFException;
String getAccessToken(SFLoginInput loginInput) throws SFException;
}
Loading

0 comments on commit 0982aaf

Please sign in to comment.