Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1831099: OAuth Client Credentials Flow Implementation #1993

Merged
merged 7 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions linkage-checker-exclusion-rules.xml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@
<Source><Package name="com.nimbusds.jose"/></Source>
<Reason>Optional</Reason>
</LinkageError>
<LinkageError>
<Target><Package name="org.brotli.dec"/></Target>
<Source><Package name="org.apache.commons.compress.compressors"/></Source>
<Reason>Optional</Reason>
</LinkageError>
<LinkageError>
<Target><Package name="com.google.appengine.api.urlfetch"/></Target>
<Source><Package name="com.google.api.client.extensions.appengine"/></Source>
Expand Down
3 changes: 2 additions & 1 deletion src/main/java/net/snowflake/client/core/AssertUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ public class AssertUtil {
* @param internalErrorMesg The error message to display if condition is false
* @throws SFException Will be thrown if condition is false
*/
static void assertTrue(boolean condition, String internalErrorMesg) throws SFException {
@SnowflakeJdbcInternalApi
public static void assertTrue(boolean condition, String internalErrorMesg) throws SFException {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

internal?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed.

if (!condition) {
throw new SFException(ErrorCode.INTERNAL_ERROR, internalErrorMesg);
}
Expand Down
4 changes: 2 additions & 2 deletions src/main/java/net/snowflake/client/core/SFSession.java
Original file line number Diff line number Diff line change
Expand Up @@ -645,9 +645,9 @@ public synchronized void open() throws SFException, SnowflakeSQLException {
(String) connectionPropertiesMap.get(SFSessionProperty.CLIENT_ID),
(String) connectionPropertiesMap.get(SFSessionProperty.CLIENT_SECRET),
(String) connectionPropertiesMap.get(SFSessionProperty.OAUTH_REDIRECT_URI),
(String) connectionPropertiesMap.get(SFSessionProperty.OAUTH_SCOPE),
(String) connectionPropertiesMap.get(SFSessionProperty.EXTERNAL_AUTHORIZATION_URL),
(String) connectionPropertiesMap.get(SFSessionProperty.EXTERNAL_TOKEN_REQUEST_URL));
(String) connectionPropertiesMap.get(SFSessionProperty.EXTERNAL_TOKEN_REQUEST_URL),
(String) connectionPropertiesMap.get(SFSessionProperty.OAUTH_SCOPE));

loginInput
.setServerUrl((String) connectionPropertiesMap.get(SFSessionProperty.SERVER_URL))
Expand Down
26 changes: 13 additions & 13 deletions src/main/java/net/snowflake/client/core/SessionUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
import net.snowflake.client.core.auth.AuthenticatorType;
import net.snowflake.client.core.auth.ClientAuthnDTO;
import net.snowflake.client.core.auth.ClientAuthnParameter;
import net.snowflake.client.core.auth.oauth.AuthorizationCodeFlowAccessTokenProvider;
import net.snowflake.client.core.auth.oauth.OauthAccessTokenProvider;
import net.snowflake.client.core.auth.oauth.AccessTokenProvider;
import net.snowflake.client.core.auth.oauth.AccessTokenProviderFactory;
import net.snowflake.client.jdbc.ErrorCode;
import net.snowflake.client.jdbc.SnowflakeDriver;
import net.snowflake.client.jdbc.SnowflakeReauthenticationRequest;
Expand Down Expand Up @@ -223,8 +223,11 @@ private static AuthenticatorType getAuthenticator(SFLoginInput loginInput) {
} else if (loginInput
.getAuthenticator()
.equalsIgnoreCase(AuthenticatorType.OAUTH_AUTHORIZATION_CODE.name())) {
// OAuth authorization code flow authentication
return AuthenticatorType.OAUTH_AUTHORIZATION_CODE;
} else if (loginInput
.getAuthenticator()
.equalsIgnoreCase(AuthenticatorType.OAUTH_CLIENT_CREDENTIALS.name())) {
return AuthenticatorType.OAUTH_CLIENT_CREDENTIALS;
} else if (loginInput.getAuthenticator().equalsIgnoreCase(AuthenticatorType.OAUTH.name())) {
// OAuth access code Authentication
return AuthenticatorType.OAUTH;
Expand Down Expand Up @@ -273,17 +276,14 @@ static SFLoginOutput openSession(
AssertUtil.assertTrue(
loginInput.getLoginTimeout() >= 0, "negative login timeout for opening session");

if (getAuthenticator(loginInput).equals(AuthenticatorType.OAUTH_AUTHORIZATION_CODE)) {
AssertUtil.assertTrue(
loginInput.getOauthLoginInput().getClientId() != null,
"passing clientId is required for OAUTH_AUTHORIZATION_CODE_FLOW authentication");
AssertUtil.assertTrue(
loginInput.getOauthLoginInput().getClientSecret() != null,
"passing clientSecret is required for OAUTH_AUTHORIZATION_CODE_FLOW authentication");
OauthAccessTokenProvider accessTokenProvider =
new AuthorizationCodeFlowAccessTokenProvider(
if (AccessTokenProviderFactory.isEligible(getAuthenticator(loginInput))) {
AccessTokenProviderFactory accessTokenProviderFactory =
new AccessTokenProviderFactory(
new SessionUtilExternalBrowser.DefaultAuthExternalBrowserHandlers(),
(int) loginInput.getBrowserResponseTimeout().getSeconds());
AccessTokenProvider accessTokenProvider =
accessTokenProviderFactory.createAccessTokenProvider(
getAuthenticator(loginInput), loginInput);
String oauthAccessToken = accessTokenProvider.getAccessToken(loginInput);
loginInput.setAuthenticator(AuthenticatorType.OAUTH.name());
loginInput.setToken(oauthAccessToken);
Expand All @@ -295,7 +295,7 @@ static SFLoginOutput openSession(
AssertUtil.assertTrue(
loginInput.getUserName() != null, "missing user name for opening session");
} else {
// OAUTH needs either token or passord
// OAUTH needs either token or password
AssertUtil.assertTrue(
loginInput.getToken() != null || loginInput.getPassword() != null,
"missing token or password for opening session");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,5 +46,10 @@ public enum AuthenticatorType {
/*
* Authorization code flow with browser popup
*/
OAUTH_AUTHORIZATION_CODE
OAUTH_AUTHORIZATION_CODE,

/*
* Client credentials flow with clientId and clientSecret as input
*/
OAUTH_CLIENT_CREDENTIALS
}
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
/*
* Copyright (c) 2024 Snowflake Computing Inc. All rights reserved.
*/

package net.snowflake.client.core.auth.oauth;

import net.snowflake.client.core.SFException;
import net.snowflake.client.core.SFLoginInput;
import net.snowflake.client.core.SnowflakeJdbcInternalApi;

@SnowflakeJdbcInternalApi
public interface OauthAccessTokenProvider {
public interface AccessTokenProvider {

String getAccessToken(SFLoginInput loginInput) throws SFException;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Copyright (c) 2024 Snowflake Computing Inc. All rights reserved.
*/

package net.snowflake.client.core.auth.oauth;

import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
import net.snowflake.client.core.AssertUtil;
import net.snowflake.client.core.SFException;
import net.snowflake.client.core.SFLoginInput;
import net.snowflake.client.core.SessionUtilExternalBrowser;
import net.snowflake.client.core.SnowflakeJdbcInternalApi;
import net.snowflake.client.core.auth.AuthenticatorType;
import net.snowflake.client.jdbc.ErrorCode;
import net.snowflake.client.log.SFLogger;
import net.snowflake.client.log.SFLoggerFactory;

@SnowflakeJdbcInternalApi
public class AccessTokenProviderFactory {

private static final SFLogger logger =
SFLoggerFactory.getLogger(AccessTokenProviderFactory.class);
private static final Set<AuthenticatorType> ELIGIBLE_AUTH_TYPES = new HashSet<>(Arrays.asList(AuthenticatorType.OAUTH_AUTHORIZATION_CODE, AuthenticatorType.OAUTH_CLIENT_CREDENTIALS));

private final SessionUtilExternalBrowser.AuthExternalBrowserHandlers browserHandler;
private final int browserAuthorizationTimeoutSeconds;

public AccessTokenProviderFactory(
SessionUtilExternalBrowser.AuthExternalBrowserHandlers browserHandler,
int browserAuthorizationTimeoutSeconds) {
this.browserHandler = browserHandler;
this.browserAuthorizationTimeoutSeconds = browserAuthorizationTimeoutSeconds;
}

public AccessTokenProvider createAccessTokenProvider(
AuthenticatorType authenticatorType, SFLoginInput loginInput) throws SFException {
switch (authenticatorType) {
case OAUTH_AUTHORIZATION_CODE:
assertContainsClientCredentials(loginInput, authenticatorType);
return new OAuthAuthorizationCodeAccessTokenProvider(
browserHandler, browserAuthorizationTimeoutSeconds);
case OAUTH_CLIENT_CREDENTIALS:
assertContainsClientCredentials(loginInput, authenticatorType);
AssertUtil.assertTrue(
loginInput.getOauthLoginInput().getExternalTokenRequestUrl() != null,
"passing externalTokenRequestUrl is required for OAUTH_CLIENT_CREDENTIALS authentication");
return new OAuthClientCredentialsAccessTokenProvider();
default:
logger.error("Unsupported authenticator type: " + authenticatorType);
throw new SFException(ErrorCode.INTERNAL_ERROR);
}
}

public static Set<AuthenticatorType> getEligible() {
return ELIGIBLE_AUTH_TYPES;
}

public static boolean isEligible(AuthenticatorType authenticatorType) {
return getEligible().contains(authenticatorType);
}

private void assertContainsClientCredentials(
SFLoginInput loginInput, AuthenticatorType authenticatorType) throws SFException {
AssertUtil.assertTrue(
loginInput.getOauthLoginInput().getClientId() != null,
String.format(
"passing clientId is required for %s authentication", authenticatorType.name()));
AssertUtil.assertTrue(
loginInput.getOauthLoginInput().getClientSecret() != null,
String.format(
"passing clientSecret is required for %s authentication", authenticatorType.name()));
}
}
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
/*
* Copyright (c) 2024 Snowflake Computing Inc. All rights reserved.
*/

package net.snowflake.client.core.auth.oauth;

import static net.snowflake.client.core.SessionUtilExternalBrowser.AuthExternalBrowserHandlers;
Expand All @@ -14,7 +18,6 @@
import com.nimbusds.oauth2.sdk.auth.ClientAuthentication;
import com.nimbusds.oauth2.sdk.auth.ClientSecretBasic;
import com.nimbusds.oauth2.sdk.auth.Secret;
import com.nimbusds.oauth2.sdk.http.HTTPRequest;
import com.nimbusds.oauth2.sdk.id.ClientID;
import com.nimbusds.oauth2.sdk.id.State;
import com.nimbusds.oauth2.sdk.pkce.CodeChallengeMethod;
Expand All @@ -38,30 +41,23 @@
import net.snowflake.client.log.SFLogger;
import net.snowflake.client.log.SFLoggerFactory;
import org.apache.http.NameValuePair;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.client.methods.HttpRequestBase;
import org.apache.http.client.utils.URLEncodedUtils;
import org.apache.http.entity.StringEntity;

@SnowflakeJdbcInternalApi
public class AuthorizationCodeFlowAccessTokenProvider implements OauthAccessTokenProvider {
public class OAuthAuthorizationCodeAccessTokenProvider implements AccessTokenProvider {

private static final SFLogger logger =
SFLoggerFactory.getLogger(AuthorizationCodeFlowAccessTokenProvider.class);

private static final String SNOWFLAKE_AUTHORIZE_ENDPOINT = "/oauth/authorize";
private static final String SNOWFLAKE_TOKEN_REQUEST_ENDPOINT = "/oauth/token-request";
SFLoggerFactory.getLogger(OAuthAuthorizationCodeAccessTokenProvider.class);

private static final String DEFAULT_REDIRECT_HOST = "http://localhost:8001";
private static final String REDIRECT_URI_ENDPOINT = "/snowflake/oauth-redirect";
private static final String DEFAULT_REDIRECT_URI = DEFAULT_REDIRECT_HOST + REDIRECT_URI_ENDPOINT;
public static final String DEFAULT_SESSION_ROLE_SCOPE_PREFIX = "session:role:";

private final AuthExternalBrowserHandlers browserHandler;
private final ObjectMapper objectMapper = new ObjectMapper();
private final int browserAuthorizationTimeoutSeconds;

public AuthorizationCodeFlowAccessTokenProvider(
public OAuthAuthorizationCodeAccessTokenProvider(
AuthExternalBrowserHandlers browserHandler, int browserAuthorizationTimeoutSeconds) {
this.browserHandler = browserHandler;
this.browserAuthorizationTimeoutSeconds = browserAuthorizationTimeoutSeconds;
Expand All @@ -70,11 +66,14 @@ public AuthorizationCodeFlowAccessTokenProvider(
@Override
public String getAccessToken(SFLoginInput loginInput) throws SFException {
try {
logger.debug("Starting OAuth authorization code authentication flow...");
CodeVerifier pkceVerifier = new CodeVerifier();
AuthorizationCode authorizationCode = requestAuthorizationCode(loginInput, pkceVerifier);
return exchangeAuthorizationCodeForAccessToken(loginInput, authorizationCode, pkceVerifier);
} catch (Exception e) {
logger.error("Error during OAuth authorization code flow", e);
logger.error(
"Error during OAuth authorization code flow. Verify configuration passed to driver and IdP (URLs, grant types, scope, etc.)",
e);
throw new SFException(e, ErrorCode.OAUTH_AUTHORIZATION_CODE_FLOW_ERROR, e.getMessage());
}
}
Expand All @@ -98,10 +97,11 @@ private String exchangeAuthorizationCodeForAccessToken(
TokenRequest request = buildTokenRequest(loginInput, authorizationCode, pkceVerifier);
URI requestUri = request.getEndpointURI();
logger.debug(
"Requesting access token from: {}", requestUri.getAuthority() + requestUri.getPath());
"Requesting OAuth access token from: {}",
requestUri.getAuthority() + requestUri.getPath());
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here also let's use parameters without concatenation

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

String tokenResponse =
HttpUtil.executeGeneralRequest(
convertToBaseRequest(request.toHTTPRequest()),
OAuthUtil.convertToBaseRequest(request.toHTTPRequest()),
loginInput.getLoginTimeout(),
loginInput.getAuthTimeout(),
loginInput.getSocketTimeoutInMillis(),
Expand All @@ -110,7 +110,9 @@ private String exchangeAuthorizationCodeForAccessToken(
TokenResponseDTO tokenResponseDTO =
objectMapper.readValue(tokenResponse, TokenResponseDTO.class);
logger.debug(
"Received OAuth access token from: {}", requestUri.getAuthority() + requestUri.getPath());
"Received OAuth access token from: {}{}",
requestUri.getAuthority(),
requestUri.getPath());
return tokenResponseDTO.getAccessToken();
} catch (Exception e) {
logger.error("Error during making OAuth access token request", e);
Expand All @@ -128,13 +130,12 @@ private AuthorizationCode letUserAuthorize(
browserHandler.openBrowser(authorizeRequestURI.toString());
String code = codeFuture.get(this.browserAuthorizationTimeoutSeconds, TimeUnit.SECONDS);
return new AuthorizationCode(code);
} catch (TimeoutException e) {
throw new SFException(
e,
ErrorCode.OAUTH_AUTHORIZATION_CODE_FLOW_ERROR,
"Authorization request timed out. Snowflake driver did not receive authorization code back to the redirect URI. Verify your security integration and driver configuration.");
} catch (Exception e) {
if (e instanceof TimeoutException) {
throw new SFException(
e,
ErrorCode.OAUTH_AUTHORIZATION_CODE_FLOW_ERROR,
"Authorization request timed out. Snowflake driver did not receive authorization code back to the redirect URI. Verify your security integration and driver configuration.");
}
throw new SFException(e, ErrorCode.OAUTH_AUTHORIZATION_CODE_FLOW_ERROR, e.getMessage());
} finally {
logger.debug("Stopping OAuth redirect URI server @ {}", httpServer.getAddress());
Expand Down Expand Up @@ -185,14 +186,15 @@ private static AuthorizationRequest buildAuthorizationRequest(
ClientID clientID = new ClientID(oauthLoginInput.getClientId());
URI callback = buildRedirectUri(oauthLoginInput);
State state = new State(256);
String scope = getScope(loginInput);
String scope = OAuthUtil.getScope(loginInput.getOauthLoginInput(), loginInput.getRole());
return new AuthorizationRequest.Builder(new ResponseType(ResponseType.Value.CODE), clientID)
.scope(new Scope(scope))
.state(state)
.redirectionURI(callback)
.codeChallenge(pkceVerifier, CodeChallengeMethod.S256)
.endpointURI(
getAuthorizationUrl(loginInput.getOauthLoginInput(), loginInput.getServerUrl()))
OAuthUtil.getAuthorizationUrl(
loginInput.getOauthLoginInput(), loginInput.getServerUrl()))
.build();
}

Expand All @@ -205,9 +207,10 @@ private static TokenRequest buildTokenRequest(
new ClientSecretBasic(
new ClientID(loginInput.getOauthLoginInput().getClientId()),
new Secret(loginInput.getOauthLoginInput().getClientSecret()));
Scope scope = new Scope(getScope(loginInput));
Scope scope =
new Scope(OAuthUtil.getScope(loginInput.getOauthLoginInput(), loginInput.getRole()));
return new TokenRequest(
getTokenRequestUrl(loginInput.getOauthLoginInput(), loginInput.getServerUrl()),
OAuthUtil.getTokenRequestUrl(loginInput.getOauthLoginInput(), loginInput.getServerUrl()),
clientAuthentication,
codeGrant,
scope);
Expand All @@ -220,29 +223,4 @@ private static URI buildRedirectUri(SFOauthLoginInput oauthLoginInput) {
: DEFAULT_REDIRECT_URI;
return URI.create(redirectUri);
}

private static HttpRequestBase convertToBaseRequest(HTTPRequest request) {
HttpPost baseRequest = new HttpPost(request.getURI());
baseRequest.setEntity(new StringEntity(request.getBody(), StandardCharsets.UTF_8));
request.getHeaderMap().forEach((key, values) -> baseRequest.addHeader(key, values.get(0)));
return baseRequest;
}

private static URI getAuthorizationUrl(SFOauthLoginInput oauthLoginInput, String serverUrl) {
return !StringUtils.isNullOrEmpty(oauthLoginInput.getExternalAuthorizationUrl())
? URI.create(oauthLoginInput.getExternalAuthorizationUrl())
: URI.create(serverUrl + SNOWFLAKE_AUTHORIZE_ENDPOINT);
}

private static URI getTokenRequestUrl(SFOauthLoginInput oauthLoginInput, String serverUrl) {
return !StringUtils.isNullOrEmpty(oauthLoginInput.getExternalTokenRequestUrl())
? URI.create(oauthLoginInput.getExternalTokenRequestUrl())
: URI.create(serverUrl + SNOWFLAKE_TOKEN_REQUEST_ENDPOINT);
}

private static String getScope(SFLoginInput loginInput) {
return (!StringUtils.isNullOrEmpty(loginInput.getOauthLoginInput().getScope()))
? loginInput.getOauthLoginInput().getScope()
: DEFAULT_SESSION_ROLE_SCOPE_PREFIX + loginInput.getRole();
}
}
Loading
Loading