Skip to content

Commit

Permalink
SNOW-993600 External OAuth2.0 Support (#718)
Browse files Browse the repository at this point in the history
* Add external OAuth support
  • Loading branch information
sfc-gh-alhuang authored Mar 25, 2024
1 parent 4d5f072 commit dbd0ce0
Show file tree
Hide file tree
Showing 8 changed files with 181 additions and 159 deletions.
115 changes: 111 additions & 4 deletions src/main/java/net/snowflake/ingest/connection/OAuthClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,118 @@

package net.snowflake.ingest.connection;

import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.net.URLEncoder;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import net.snowflake.client.jdbc.internal.apache.http.HttpHeaders;
import net.snowflake.client.jdbc.internal.apache.http.client.methods.CloseableHttpResponse;
import net.snowflake.client.jdbc.internal.apache.http.client.methods.HttpPost;
import net.snowflake.client.jdbc.internal.apache.http.client.methods.HttpUriRequest;
import net.snowflake.client.jdbc.internal.apache.http.client.utils.URIBuilder;
import net.snowflake.client.jdbc.internal.apache.http.entity.ContentType;
import net.snowflake.client.jdbc.internal.apache.http.entity.StringEntity;
import net.snowflake.client.jdbc.internal.apache.http.impl.client.CloseableHttpClient;
import net.snowflake.client.jdbc.internal.apache.http.util.EntityUtils;
import net.snowflake.client.jdbc.internal.google.api.client.http.HttpStatusCodes;
import net.snowflake.client.jdbc.internal.google.gson.JsonObject;
import net.snowflake.client.jdbc.internal.google.gson.JsonParser;
import net.snowflake.ingest.utils.ErrorCode;
import net.snowflake.ingest.utils.HttpUtil;
import net.snowflake.ingest.utils.SFException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/** Interface to perform token refresh request from {@link OAuthManager} */
public interface OAuthClient {
AtomicReference<OAuthCredential> getoAuthCredentialRef();
/*
* Implementation of OAuth Client, used for refreshing an OAuth access token.
*/
public class OAuthClient {

static final Logger LOGGER = LoggerFactory.getLogger(OAuthClient.class);
private static final String TOKEN_REQUEST_ENDPOINT = "/oauth/token-request";

// Content type header to specify the encoding
private static final String OAUTH_CONTENT_TYPE_HEADER = "application/x-www-form-urlencoded";
private static final String GRANT_TYPE_PARAM = "grant_type";
private static final String ACCESS_TOKEN = "access_token";
private static final String REFRESH_TOKEN = "refresh_token";
private static final String EXPIRES_IN = "expires_in";

// OAuth credential
private final AtomicReference<OAuthCredential> oAuthCredential;

// Http client for submitting token refresh request
private final CloseableHttpClient httpClient;

/**
* Creates an AuthClient for Snowflake OAuth given account, credential and base uri
*
* @param accountName - the snowflake account name of this user
* @param oAuthCredential - the OAuth credential we're using to connect
* @param baseURIBuilder - the uri builder with common scheme, host and port
*/
OAuthClient(String accountName, OAuthCredential oAuthCredential, URIBuilder baseURIBuilder) {
this.oAuthCredential = new AtomicReference<>(oAuthCredential);

// build token request uri
baseURIBuilder.setPath(TOKEN_REQUEST_ENDPOINT);
this.httpClient = HttpUtil.getHttpClient(accountName);
}

/** Get access token */
public AtomicReference<OAuthCredential> getOAuthCredentialRef() {
return oAuthCredential;
}

/** Refresh access token using a valid refresh token */
public void refreshToken() {
String respBodyString = null;
try (CloseableHttpResponse httpResponse = httpClient.execute(makeRefreshTokenRequest())) {
respBodyString = EntityUtils.toString(httpResponse.getEntity());

if (httpResponse.getStatusLine().getStatusCode() == HttpStatusCodes.STATUS_CODE_OK) {
JsonObject respBody = JsonParser.parseString(respBodyString).getAsJsonObject();

if (respBody.has(ACCESS_TOKEN) && respBody.has(EXPIRES_IN)) {
// Trim surrounding quotation marks
String newAccessToken = respBody.get(ACCESS_TOKEN).toString().replaceAll("^\"|\"$", "");
oAuthCredential.get().setAccessToken(newAccessToken);
oAuthCredential.get().setExpiresIn(respBody.get(EXPIRES_IN).getAsInt());
return;
}
}
throw new SFException(
ErrorCode.OAUTH_REFRESH_TOKEN_ERROR,
"Refresh access token fail with response: " + respBodyString);
} catch (IOException e) {
throw new SFException(ErrorCode.OAUTH_REFRESH_TOKEN_ERROR, e.getMessage());
}
}

/** Helper method for making refresh request */
private HttpUriRequest makeRefreshTokenRequest() {
HttpPost post = new HttpPost(oAuthCredential.get().getOAuthTokenEndpoint());
post.addHeader(HttpHeaders.CONTENT_TYPE, OAUTH_CONTENT_TYPE_HEADER);
post.addHeader(HttpHeaders.AUTHORIZATION, oAuthCredential.get().getAuthHeader());

Map<Object, Object> payload = new HashMap<>();
try {
payload.put(GRANT_TYPE_PARAM, URLEncoder.encode(REFRESH_TOKEN, "UTF-8"));
payload.put(
REFRESH_TOKEN, URLEncoder.encode(oAuthCredential.get().getRefreshToken(), "UTF-8"));
} catch (UnsupportedEncodingException e) {
throw new SFException(e, ErrorCode.OAUTH_REFRESH_TOKEN_ERROR, e.getMessage());
}

String payloadString =
payload.entrySet().stream()
.map(e -> e.getKey() + "=" + e.getValue())
.collect(Collectors.joining("&"));
post.setEntity(new StringEntity(payloadString, ContentType.APPLICATION_FORM_URLENCODED));

void refreshToken();
return post;
}
}
20 changes: 14 additions & 6 deletions src/main/java/net/snowflake/ingest/connection/OAuthCredential.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,29 @@

package net.snowflake.ingest.connection;

import java.net.URI;
import java.util.Base64;

/** This class hold credentials for OAuth authentication */
public class OAuthCredential {
private static final String BASIC_AUTH_HEADER_PREFIX = "Basic ";
private final String authHeader;
private final String clientId;
private final String clientSecret;
private String accessToken;
private String refreshToken;
private final URI oAuthTokenEndpoint;
private transient String accessToken;
private transient String refreshToken;
private int expiresIn;

public OAuthCredential(String clientId, String clientSecret, String refreshToken) {
this(clientId, clientSecret, refreshToken, null);
}

public OAuthCredential(
String clientId, String clientSecret, String refreshToken, URI oAuthTokenEndpoint) {
this.authHeader =
BASIC_AUTH_HEADER_PREFIX
+ Base64.getEncoder().encodeToString((clientId + ":" + clientSecret).getBytes());
this.clientId = clientId;
this.clientSecret = clientSecret;
this.refreshToken = refreshToken;
this.oAuthTokenEndpoint = oAuthTokenEndpoint;
}

public String getAuthHeader() {
Expand All @@ -45,6 +49,10 @@ public void setRefreshToken(String refreshToken) {
this.refreshToken = refreshToken;
}

public URI getOAuthTokenEndpoint() {
return oAuthTokenEndpoint;
}

public void setExpiresIn(int expiresIn) {
this.expiresIn = expiresIn;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ public final class OAuthManager extends SecurityManager {
throw new IllegalArgumentException("updateThresholdRatio should fall in (0, 1)");
}
this.updateThresholdRatio = updateThresholdRatio;
this.oAuthClient = new SnowflakeOAuthClient(accountName, oAuthCredential, baseURIBuilder);
this.oAuthClient = new OAuthClient(accountName, oAuthCredential, baseURIBuilder);

// generate our first token
refreshToken();
Expand Down Expand Up @@ -119,7 +119,7 @@ String getToken() {
if (refreshFailed.get()) {
throw new SecurityException("getToken request failed due to token refresh failure");
}
return oAuthClient.getoAuthCredentialRef().get().getAccessToken();
return oAuthClient.getOAuthCredentialRef().get().getAccessToken();
}

@Override
Expand All @@ -134,7 +134,7 @@ String getTokenType() {
* @param refreshToken the new refresh token
*/
void setRefreshToken(String refreshToken) {
oAuthClient.getoAuthCredentialRef().get().setRefreshToken(refreshToken);
oAuthClient.getOAuthCredentialRef().get().setRefreshToken(refreshToken);
}

/** refreshToken - Get new access token using refresh_token, client_id, client_secret */
Expand All @@ -147,7 +147,7 @@ void refreshToken() {
// Schedule next refresh
long nextRefreshDelay =
(long)
(oAuthClient.getoAuthCredentialRef().get().getExpiresIn()
(oAuthClient.getOAuthCredentialRef().get().getExpiresIn()
* this.updateThresholdRatio);
tokenRefresher.schedule(this::refreshToken, nextRefreshDelay, TimeUnit.SECONDS);
LOGGER.debug(
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.annotations.VisibleForTesting;
import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.security.NoSuchAlgorithmException;
import java.security.PrivateKey;
import java.security.spec.InvalidKeySpecException;
Expand All @@ -60,6 +62,7 @@
import javax.management.MalformedObjectNameException;
import javax.management.ObjectName;
import net.snowflake.client.core.SFSessionProperty;
import net.snowflake.client.jdbc.internal.apache.http.client.utils.URIBuilder;
import net.snowflake.client.jdbc.internal.apache.http.impl.client.CloseableHttpClient;
import net.snowflake.ingest.connection.IngestResponseException;
import net.snowflake.ingest.connection.OAuthCredential;
Expand Down Expand Up @@ -187,11 +190,29 @@ public class SnowflakeStreamingIngestClientInternal<T> implements SnowflakeStrea
throw new SFException(e, ErrorCode.KEYPAIR_CREATION_FAILURE);
}
} else {
URI oAuthTokenEndpoint;
try {
if (prop.getProperty(Constants.OAUTH_TOKEN_ENDPOINT) == null) {
// Set OAuth token endpoint to Snowflake OAuth by default
oAuthTokenEndpoint =
new URIBuilder()
.setScheme(accountURL.getScheme())
.setHost(accountURL.getUrlWithoutPort())
.setPort(accountURL.getPort())
.setPath(Constants.SNOWFLAKE_OAUTH_TOKEN_ENDPOINT)
.build();
} else {
oAuthTokenEndpoint = new URI(prop.getProperty(Constants.OAUTH_TOKEN_ENDPOINT));
}
} catch (URISyntaxException e) {
throw new SFException(e, ErrorCode.INVALID_URL);
}
credential =
new OAuthCredential(
prop.getProperty(Constants.OAUTH_CLIENT_ID),
prop.getProperty(Constants.OAUTH_CLIENT_SECRET),
prop.getProperty(Constants.OAUTH_REFRESH_TOKEN));
prop.getProperty(Constants.OAUTH_REFRESH_TOKEN),
oAuthTokenEndpoint);
}
this.requestBuilder =
new RequestBuilder(
Expand Down
Loading

0 comments on commit dbd0ce0

Please sign in to comment.