Skip to content

Commit

Permalink
Draft OAuth authz code flow
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-dheyman committed Nov 30, 2024
1 parent c62c5e4 commit 1741680
Show file tree
Hide file tree
Showing 7 changed files with 163 additions and 5 deletions.
10 changes: 10 additions & 0 deletions parent-pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
<mockito.version>4.11.0</mockito.version>
<netty.version>4.1.115.Final</netty.version>
<nimbusds.version>9.37.3</nimbusds.version>
<nimbusds.oauth2.version>11.20.1</nimbusds.oauth2.version>
<opencensus.version>0.31.1</opencensus.version>
<plexus.container.version>1.0-alpha-9-stable-1</plexus.container.version>
<plexus.utils.version>3.4.2</plexus.utils.version>
Expand Down Expand Up @@ -218,6 +219,11 @@
<artifactId>nimbus-jose-jwt</artifactId>
<version>${nimbusds.version}</version>
</dependency>
<dependency>
<groupId>com.nimbusds</groupId>
<artifactId>oauth2-oidc-sdk</artifactId>
<version>${nimbusds.oauth2.version}</version>
</dependency>
<dependency>
<groupId>com.yammer.metrics</groupId>
<artifactId>metrics-core</artifactId>
Expand Down Expand Up @@ -645,6 +651,10 @@
<groupId>com.nimbusds</groupId>
<artifactId>nimbus-jose-jwt</artifactId>
</dependency>
<dependency>
<groupId>com.nimbusds</groupId>
<artifactId>oauth2-oidc-sdk</artifactId>
</dependency>
<dependency>
<groupId>com.yammer.metrics</groupId>
<artifactId>metrics-core</artifactId>
Expand Down
8 changes: 4 additions & 4 deletions src/main/java/net/snowflake/client/core/SFLoginInput.java
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ public SFLoginInput setAccountName(String accountName) {
return this;
}

int getLoginTimeout() {
public int getLoginTimeout() {
return loginTimeout;
}

Expand All @@ -184,7 +184,7 @@ SFLoginInput setRetryTimeout(int retryTimeout) {
return this;
}

int getAuthTimeout() {
public int getAuthTimeout() {
return authTimeout;
}

Expand Down Expand Up @@ -238,7 +238,7 @@ SFLoginInput setConnectionTimeout(Duration connectionTimeout) {
return this;
}

int getSocketTimeoutInMillis() {
public int getSocketTimeoutInMillis() {
return (int) socketTimeout.toMillis();
}

Expand Down Expand Up @@ -388,7 +388,7 @@ SFLoginInput setOCSPMode(OCSPMode ocspMode) {
return this;
}

HttpClientSettingsKey getHttpClientSettingsKey() {
public HttpClientSettingsKey getHttpClientSettingsKey() {
return httpClientKey;
}

Expand Down
9 changes: 9 additions & 0 deletions src/main/java/net/snowflake/client/core/SessionUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
import net.snowflake.client.core.auth.AuthenticatorType;
import net.snowflake.client.core.auth.ClientAuthnDTO;
import net.snowflake.client.core.auth.ClientAuthnParameter;
import net.snowflake.client.core.auth.oauth.AuthorizationCodeFlowAccessTokenProvider;
import net.snowflake.client.core.auth.oauth.OauthAccessTokenProvider;
import net.snowflake.client.jdbc.ErrorCode;
import net.snowflake.client.jdbc.SnowflakeDriver;
import net.snowflake.client.jdbc.SnowflakeReauthenticationRequest;
Expand Down Expand Up @@ -265,6 +267,13 @@ static SFLoginOutput openSession(
AssertUtil.assertTrue(
loginInput.getLoginTimeout() >= 0, "negative login timeout for opening session");

if (getAuthenticator(loginInput).equals(AuthenticatorType.OAUTH_AUTHORIZATION_CODE_FLOW)) {
OauthAccessTokenProvider accessTokenProvider = new AuthorizationCodeFlowAccessTokenProvider();
String oauthAccessToken = accessTokenProvider.getAccessToken(loginInput);
loginInput.setAuthenticator(AuthenticatorType.OAUTH.name());
loginInput.setToken(oauthAccessToken);
}

final AuthenticatorType authenticator = getAuthenticator(loginInput);
if (!authenticator.equals(AuthenticatorType.OAUTH)) {
// OAuth does not require a username
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,5 +41,10 @@ public enum AuthenticatorType {
/*
* Authenticator to enable token for regular login with mfa
*/
USERNAME_PASSWORD_MFA
USERNAME_PASSWORD_MFA,

/*
* Authorization code flow with browser popup
*/
OAUTH_AUTHORIZATION_CODE_FLOW
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
package net.snowflake.client.core.auth.oauth;

import com.nimbusds.oauth2.sdk.AccessTokenResponse;
import com.nimbusds.oauth2.sdk.AuthorizationCode;
import com.nimbusds.oauth2.sdk.AuthorizationCodeGrant;
import com.nimbusds.oauth2.sdk.AuthorizationGrant;
import com.nimbusds.oauth2.sdk.AuthorizationRequest;
import com.nimbusds.oauth2.sdk.ResponseType;
import com.nimbusds.oauth2.sdk.Scope;
import com.nimbusds.oauth2.sdk.TokenErrorResponse;
import com.nimbusds.oauth2.sdk.TokenRequest;
import com.nimbusds.oauth2.sdk.TokenResponse;
import com.nimbusds.oauth2.sdk.auth.ClientAuthentication;
import com.nimbusds.oauth2.sdk.auth.ClientSecretBasic;
import com.nimbusds.oauth2.sdk.auth.Secret;
import com.nimbusds.oauth2.sdk.id.ClientID;
import com.nimbusds.oauth2.sdk.id.State;
import com.nimbusds.oauth2.sdk.token.AccessToken;
import com.sun.net.httpserver.HttpServer;
import net.snowflake.client.core.HttpUtil;
import net.snowflake.client.core.SFException;
import net.snowflake.client.core.SFLoginInput;
import net.snowflake.client.core.SnowflakeJdbcInternalApi;
import net.snowflake.client.jdbc.ErrorCode;
import org.apache.http.client.methods.HttpGet;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.concurrent.CompletableFuture;

@SnowflakeJdbcInternalApi
public class AuthorizationCodeFlowAccessTokenProvider implements OauthAccessTokenProvider {

private static final String REDIRECT_URI_HOST = "localhost";
private static final int REDIRECT_URI_PORT = 8001;
private static final String REDIRECT_URI_PATH = "/oauth-redirect";

@Override
public String getAccessToken(SFLoginInput loginInput) throws SFException {
AuthorizationCode authorizationCode = requestAuthorizationCode(loginInput);
AccessToken accessToken = exchangeAuthorizationCodeForAccessToken(loginInput, authorizationCode);
return accessToken.getValue();
}

private AuthorizationCode requestAuthorizationCode(SFLoginInput loginInput) throws SFException {
try {
AuthorizationRequest request = buildAuthorizationRequest(loginInput);

URI requestURI = request.toURI();
HttpUtil.executeGeneralRequest(new HttpGet(requestURI),
loginInput.getLoginTimeout(),
loginInput.getAuthTimeout(),
loginInput.getSocketTimeoutInMillis(),
0,
loginInput.getHttpClientSettingsKey());
CompletableFuture<String> f = getAuthorizationCodeFromRedirectURI();
f.join();
return new AuthorizationCode(f.get());
} catch (Exception e) {
throw new SFException(e, ErrorCode.INTERNAL_ERROR);
}
}

private static AccessToken exchangeAuthorizationCodeForAccessToken(SFLoginInput loginInput, AuthorizationCode authorizationCode) throws SFException {
try {
TokenRequest request = buildTokenRequest(loginInput, authorizationCode);
TokenResponse response = TokenResponse.parse(request.toHTTPRequest().send());
if (!response.indicatesSuccess()) {
TokenErrorResponse errorResponse = response.toErrorResponse();
errorResponse.getErrorObject();
}
AccessTokenResponse successResponse = response.toSuccessResponse();
return successResponse.getTokens().getAccessToken();
} catch (Exception e) {
throw new SFException(e, ErrorCode.INTERNAL_ERROR);
}
}

private static CompletableFuture<String> getAuthorizationCodeFromRedirectURI() throws IOException {
CompletableFuture<String> accessTokenFuture = new CompletableFuture<>();
HttpServer httpServer = HttpServer.create(new InetSocketAddress(REDIRECT_URI_HOST, REDIRECT_URI_PORT), 0);
httpServer.createContext(REDIRECT_URI_PATH, exchange -> {
String authorizationCode = exchange.getRequestURI().getQuery();
accessTokenFuture.complete(authorizationCode);
httpServer.stop(0);
});
return accessTokenFuture;
}

private static AuthorizationRequest buildAuthorizationRequest(SFLoginInput loginInput) throws URISyntaxException {
URI authorizeEndpoint = new URI(String.format("%s/oauth/authorize", loginInput.getServerUrl()));
ClientID clientID = new ClientID("123");
Scope scope = new Scope("read", "write");
URI callback = buildRedirectURI();
State state = new State();
return new AuthorizationRequest.Builder(
new ResponseType(ResponseType.Value.CODE), clientID)
.scope(scope)
.state(state)
.redirectionURI(callback)
.endpointURI(authorizeEndpoint)
.build();
}

private static URI buildRedirectURI() throws URISyntaxException {
return new URI(String.format("https://%s:%s%s", REDIRECT_URI_HOST, REDIRECT_URI_PORT, REDIRECT_URI_PATH));
}

private static TokenRequest buildTokenRequest(SFLoginInput loginInput, AuthorizationCode authorizationCode) throws URISyntaxException {
URI callback = buildRedirectURI();
AuthorizationGrant codeGrant = new AuthorizationCodeGrant(authorizationCode, callback);
ClientID clientID = new ClientID("123");
Secret clientSecret = new Secret("123");
ClientAuthentication clientAuthentication = new ClientSecretBasic(clientID, clientSecret);
URI tokenEndpoint = new URI(String.format("%s/oauth/token", loginInput.getServerUrl()));
return new TokenRequest(tokenEndpoint, clientAuthentication, codeGrant, new Scope());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package net.snowflake.client.core.auth.oauth;

import net.snowflake.client.core.SFException;
import net.snowflake.client.core.SFLoginInput;

public interface OauthAccessTokenProvider {

String getAccessToken(SFLoginInput loginInput) throws SFException;
}
5 changes: 5 additions & 0 deletions thin_public_pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,11 @@
<artifactId>nimbus-jose-jwt</artifactId>
<version>${nimbusds.version}</version>
</dependency>
<dependency>
<groupId>com.nimbusds</groupId>
<artifactId>oauth2-oidc-sdk</artifactId>
<version>${nimbusds.oauth2.version}</version>
</dependency>
<dependency>
<groupId>com.yammer.metrics</groupId>
<artifactId>metrics-core</artifactId>
Expand Down

0 comments on commit 1741680

Please sign in to comment.