diff --git a/src/main/java/net/snowflake/ingest/connection/OAuthClient.java b/src/main/java/net/snowflake/ingest/connection/OAuthClient.java index bdff6c96e..bc2b64089 100644 --- a/src/main/java/net/snowflake/ingest/connection/OAuthClient.java +++ b/src/main/java/net/snowflake/ingest/connection/OAuthClient.java @@ -4,11 +4,118 @@ package net.snowflake.ingest.connection; +import java.io.IOException; +import java.io.UnsupportedEncodingException; +import java.net.URLEncoder; +import java.util.HashMap; +import java.util.Map; import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; +import net.snowflake.client.jdbc.internal.apache.http.HttpHeaders; +import net.snowflake.client.jdbc.internal.apache.http.client.methods.CloseableHttpResponse; +import net.snowflake.client.jdbc.internal.apache.http.client.methods.HttpPost; +import net.snowflake.client.jdbc.internal.apache.http.client.methods.HttpUriRequest; +import net.snowflake.client.jdbc.internal.apache.http.client.utils.URIBuilder; +import net.snowflake.client.jdbc.internal.apache.http.entity.ContentType; +import net.snowflake.client.jdbc.internal.apache.http.entity.StringEntity; +import net.snowflake.client.jdbc.internal.apache.http.impl.client.CloseableHttpClient; +import net.snowflake.client.jdbc.internal.apache.http.util.EntityUtils; +import net.snowflake.client.jdbc.internal.google.api.client.http.HttpStatusCodes; +import net.snowflake.client.jdbc.internal.google.gson.JsonObject; +import net.snowflake.client.jdbc.internal.google.gson.JsonParser; +import net.snowflake.ingest.utils.ErrorCode; +import net.snowflake.ingest.utils.HttpUtil; +import net.snowflake.ingest.utils.SFException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; -/** Interface to perform token refresh request from {@link OAuthManager} */ -public interface OAuthClient { - AtomicReference getoAuthCredentialRef(); +/* + * Implementation of OAuth Client, used for refreshing an OAuth access token. + */ +public class OAuthClient { + + static final Logger LOGGER = LoggerFactory.getLogger(OAuthClient.class); + private static final String TOKEN_REQUEST_ENDPOINT = "/oauth/token-request"; + + // Content type header to specify the encoding + private static final String OAUTH_CONTENT_TYPE_HEADER = "application/x-www-form-urlencoded"; + private static final String GRANT_TYPE_PARAM = "grant_type"; + private static final String ACCESS_TOKEN = "access_token"; + private static final String REFRESH_TOKEN = "refresh_token"; + private static final String EXPIRES_IN = "expires_in"; + + // OAuth credential + private final AtomicReference oAuthCredential; + + // Http client for submitting token refresh request + private final CloseableHttpClient httpClient; + + /** + * Creates an AuthClient for Snowflake OAuth given account, credential and base uri + * + * @param accountName - the snowflake account name of this user + * @param oAuthCredential - the OAuth credential we're using to connect + * @param baseURIBuilder - the uri builder with common scheme, host and port + */ + OAuthClient(String accountName, OAuthCredential oAuthCredential, URIBuilder baseURIBuilder) { + this.oAuthCredential = new AtomicReference<>(oAuthCredential); + + // build token request uri + baseURIBuilder.setPath(TOKEN_REQUEST_ENDPOINT); + this.httpClient = HttpUtil.getHttpClient(accountName); + } + + /** Get access token */ + public AtomicReference getOAuthCredentialRef() { + return oAuthCredential; + } + + /** Refresh access token using a valid refresh token */ + public void refreshToken() { + String respBodyString = null; + try (CloseableHttpResponse httpResponse = httpClient.execute(makeRefreshTokenRequest())) { + respBodyString = EntityUtils.toString(httpResponse.getEntity()); + + if (httpResponse.getStatusLine().getStatusCode() == HttpStatusCodes.STATUS_CODE_OK) { + JsonObject respBody = JsonParser.parseString(respBodyString).getAsJsonObject(); + + if (respBody.has(ACCESS_TOKEN) && respBody.has(EXPIRES_IN)) { + // Trim surrounding quotation marks + String newAccessToken = respBody.get(ACCESS_TOKEN).toString().replaceAll("^\"|\"$", ""); + oAuthCredential.get().setAccessToken(newAccessToken); + oAuthCredential.get().setExpiresIn(respBody.get(EXPIRES_IN).getAsInt()); + return; + } + } + throw new SFException( + ErrorCode.OAUTH_REFRESH_TOKEN_ERROR, + "Refresh access token fail with response: " + respBodyString); + } catch (IOException e) { + throw new SFException(ErrorCode.OAUTH_REFRESH_TOKEN_ERROR, e.getMessage()); + } + } + + /** Helper method for making refresh request */ + private HttpUriRequest makeRefreshTokenRequest() { + HttpPost post = new HttpPost(oAuthCredential.get().getOAuthTokenEndpoint()); + post.addHeader(HttpHeaders.CONTENT_TYPE, OAUTH_CONTENT_TYPE_HEADER); + post.addHeader(HttpHeaders.AUTHORIZATION, oAuthCredential.get().getAuthHeader()); + + Map payload = new HashMap<>(); + try { + payload.put(GRANT_TYPE_PARAM, URLEncoder.encode(REFRESH_TOKEN, "UTF-8")); + payload.put( + REFRESH_TOKEN, URLEncoder.encode(oAuthCredential.get().getRefreshToken(), "UTF-8")); + } catch (UnsupportedEncodingException e) { + throw new SFException(e, ErrorCode.OAUTH_REFRESH_TOKEN_ERROR, e.getMessage()); + } + + String payloadString = + payload.entrySet().stream() + .map(e -> e.getKey() + "=" + e.getValue()) + .collect(Collectors.joining("&")); + post.setEntity(new StringEntity(payloadString, ContentType.APPLICATION_FORM_URLENCODED)); - void refreshToken(); + return post; + } } diff --git a/src/main/java/net/snowflake/ingest/connection/OAuthCredential.java b/src/main/java/net/snowflake/ingest/connection/OAuthCredential.java index e222478d1..a5136d7cb 100644 --- a/src/main/java/net/snowflake/ingest/connection/OAuthCredential.java +++ b/src/main/java/net/snowflake/ingest/connection/OAuthCredential.java @@ -4,25 +4,29 @@ package net.snowflake.ingest.connection; +import java.net.URI; import java.util.Base64; /** This class hold credentials for OAuth authentication */ public class OAuthCredential { private static final String BASIC_AUTH_HEADER_PREFIX = "Basic "; private final String authHeader; - private final String clientId; - private final String clientSecret; - private String accessToken; - private String refreshToken; + private final URI oAuthTokenEndpoint; + private transient String accessToken; + private transient String refreshToken; private int expiresIn; public OAuthCredential(String clientId, String clientSecret, String refreshToken) { + this(clientId, clientSecret, refreshToken, null); + } + + public OAuthCredential( + String clientId, String clientSecret, String refreshToken, URI oAuthTokenEndpoint) { this.authHeader = BASIC_AUTH_HEADER_PREFIX + Base64.getEncoder().encodeToString((clientId + ":" + clientSecret).getBytes()); - this.clientId = clientId; - this.clientSecret = clientSecret; this.refreshToken = refreshToken; + this.oAuthTokenEndpoint = oAuthTokenEndpoint; } public String getAuthHeader() { @@ -45,6 +49,10 @@ public void setRefreshToken(String refreshToken) { this.refreshToken = refreshToken; } + public URI getOAuthTokenEndpoint() { + return oAuthTokenEndpoint; + } + public void setExpiresIn(int expiresIn) { this.expiresIn = expiresIn; } diff --git a/src/main/java/net/snowflake/ingest/connection/OAuthManager.java b/src/main/java/net/snowflake/ingest/connection/OAuthManager.java index fefb94299..a30396e7b 100644 --- a/src/main/java/net/snowflake/ingest/connection/OAuthManager.java +++ b/src/main/java/net/snowflake/ingest/connection/OAuthManager.java @@ -88,7 +88,7 @@ public final class OAuthManager extends SecurityManager { throw new IllegalArgumentException("updateThresholdRatio should fall in (0, 1)"); } this.updateThresholdRatio = updateThresholdRatio; - this.oAuthClient = new SnowflakeOAuthClient(accountName, oAuthCredential, baseURIBuilder); + this.oAuthClient = new OAuthClient(accountName, oAuthCredential, baseURIBuilder); // generate our first token refreshToken(); @@ -119,7 +119,7 @@ String getToken() { if (refreshFailed.get()) { throw new SecurityException("getToken request failed due to token refresh failure"); } - return oAuthClient.getoAuthCredentialRef().get().getAccessToken(); + return oAuthClient.getOAuthCredentialRef().get().getAccessToken(); } @Override @@ -134,7 +134,7 @@ String getTokenType() { * @param refreshToken the new refresh token */ void setRefreshToken(String refreshToken) { - oAuthClient.getoAuthCredentialRef().get().setRefreshToken(refreshToken); + oAuthClient.getOAuthCredentialRef().get().setRefreshToken(refreshToken); } /** refreshToken - Get new access token using refresh_token, client_id, client_secret */ @@ -147,7 +147,7 @@ void refreshToken() { // Schedule next refresh long nextRefreshDelay = (long) - (oAuthClient.getoAuthCredentialRef().get().getExpiresIn() + (oAuthClient.getOAuthCredentialRef().get().getExpiresIn() * this.updateThresholdRatio); tokenRefresher.schedule(this::refreshToken, nextRefreshDelay, TimeUnit.SECONDS); LOGGER.debug( diff --git a/src/main/java/net/snowflake/ingest/connection/SnowflakeOAuthClient.java b/src/main/java/net/snowflake/ingest/connection/SnowflakeOAuthClient.java deleted file mode 100644 index a90540659..000000000 --- a/src/main/java/net/snowflake/ingest/connection/SnowflakeOAuthClient.java +++ /dev/null @@ -1,138 +0,0 @@ -/* - * Copyright (c) 2023 Snowflake Computing Inc. All rights reserved. - */ - -package net.snowflake.ingest.connection; - -import java.io.UnsupportedEncodingException; -import java.net.URI; -import java.net.URISyntaxException; -import java.net.URLEncoder; -import java.util.HashMap; -import java.util.Map; -import java.util.concurrent.atomic.AtomicReference; -import java.util.stream.Collectors; -import net.snowflake.client.jdbc.internal.apache.http.HttpHeaders; -import net.snowflake.client.jdbc.internal.apache.http.client.methods.CloseableHttpResponse; -import net.snowflake.client.jdbc.internal.apache.http.client.methods.HttpPost; -import net.snowflake.client.jdbc.internal.apache.http.client.methods.HttpUriRequest; -import net.snowflake.client.jdbc.internal.apache.http.client.utils.URIBuilder; -import net.snowflake.client.jdbc.internal.apache.http.entity.ContentType; -import net.snowflake.client.jdbc.internal.apache.http.entity.StringEntity; -import net.snowflake.client.jdbc.internal.apache.http.impl.client.CloseableHttpClient; -import net.snowflake.client.jdbc.internal.apache.http.util.EntityUtils; -import net.snowflake.client.jdbc.internal.google.api.client.http.HttpStatusCodes; -import net.snowflake.client.jdbc.internal.google.gson.JsonObject; -import net.snowflake.client.jdbc.internal.google.gson.JsonParser; -import net.snowflake.ingest.utils.ErrorCode; -import net.snowflake.ingest.utils.HttpUtil; -import net.snowflake.ingest.utils.SFException; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/* - * Implementation of Snowflake OAuth Client, used for refreshing an OAuth access token. - */ -public class SnowflakeOAuthClient implements OAuthClient { - - static final Logger LOGGER = LoggerFactory.getLogger(SnowflakeOAuthClient.class); - private static final String TOKEN_REQUEST_ENDPOINT = "/oauth/token-request"; - - // Content type header to specify the encoding - private static final String OAUTH_CONTENT_TYPE_HEADER = "application/x-www-form-urlencoded"; - private static final String GRANT_TYPE_PARAM = "grant_type"; - private static final String ACCESS_TOKEN = "access_token"; - private static final String REFRESH_TOKEN = "refresh_token"; - private static final String EXPIRES_IN = "expires_in"; - - // OAuth credential - private final AtomicReference oAuthCredential; - - // exact uri for token request - private final URI tokenRequestURI; - - // Http client for submitting token refresh request - private final CloseableHttpClient httpClient; - - /** - * Creates an SnowflakeOAuthClient given account, credential and base uri - * - * @param accountName - the snowflake account name of this user - * @param oAuthCredential - the OAuth credential we're using to connect - * @param baseURIBuilder - the uri builder with common scheme, host and port - */ - SnowflakeOAuthClient( - String accountName, OAuthCredential oAuthCredential, URIBuilder baseURIBuilder) { - this.oAuthCredential = new AtomicReference<>(oAuthCredential); - - // build token request uri - baseURIBuilder.setPath(TOKEN_REQUEST_ENDPOINT); - try { - this.tokenRequestURI = baseURIBuilder.build(); - } catch (URISyntaxException e) { - throw new SFException(e, ErrorCode.MAKE_URI_FAILURE, e.getMessage()); - } - - this.httpClient = HttpUtil.getHttpClient(accountName); - } - - /** Get access token */ - @Override - public AtomicReference getoAuthCredentialRef() { - return oAuthCredential; - } - - /** Refresh access token using a valid refresh token */ - @Override - public void refreshToken() { - String respBodyString = null; - try (CloseableHttpResponse httpResponse = httpClient.execute(makeRefreshTokenRequest())) { - respBodyString = EntityUtils.toString(httpResponse.getEntity()); - - if (httpResponse.getStatusLine().getStatusCode() == HttpStatusCodes.STATUS_CODE_OK) { - JsonObject respBody = JsonParser.parseString(respBodyString).getAsJsonObject(); - - if (respBody.has(ACCESS_TOKEN) && respBody.has(EXPIRES_IN)) { - // Trim surrounding quotation marks - String newAccessToken = respBody.get(ACCESS_TOKEN).toString().replaceAll("^\"|\"$", ""); - - oAuthCredential.get().setAccessToken(newAccessToken); - oAuthCredential.get().setExpiresIn(respBody.get(EXPIRES_IN).getAsInt()); - return; - } - } - throw new SFException( - ErrorCode.OAUTH_REFRESH_TOKEN_ERROR, - "Refresh access token fail with response: " + respBodyString); - } catch (Exception e) { - throw new SFException(ErrorCode.OAUTH_REFRESH_TOKEN_ERROR, e.getMessage()); - } - } - - /** Helper method for making refresh request */ - private HttpUriRequest makeRefreshTokenRequest() { - HttpPost post = new HttpPost(tokenRequestURI); - post.addHeader(HttpHeaders.CONTENT_TYPE, OAUTH_CONTENT_TYPE_HEADER); - post.addHeader(HttpHeaders.AUTHORIZATION, oAuthCredential.get().getAuthHeader()); - - Map payload = new HashMap<>(); - try { - payload.put(GRANT_TYPE_PARAM, URLEncoder.encode(REFRESH_TOKEN, "UTF-8")); - payload.put( - REFRESH_TOKEN, URLEncoder.encode(oAuthCredential.get().getRefreshToken(), "UTF-8")); - } catch (UnsupportedEncodingException e) { - throw new SFException(e, ErrorCode.OAUTH_REFRESH_TOKEN_ERROR, e.getMessage()); - } - - String payloadString = - payload.entrySet().stream() - .map(e -> e.getKey() + "=" + e.getValue()) - .collect(Collectors.joining("&")); - - final StringEntity entity = - new StringEntity(payloadString, ContentType.APPLICATION_FORM_URLENCODED); - post.setEntity(entity); - - return post; - } -} diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientInternal.java b/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientInternal.java index 3c4896ee1..2990b49d8 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientInternal.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientInternal.java @@ -40,6 +40,8 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.annotations.VisibleForTesting; import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; import java.security.NoSuchAlgorithmException; import java.security.PrivateKey; import java.security.spec.InvalidKeySpecException; @@ -60,6 +62,7 @@ import javax.management.MalformedObjectNameException; import javax.management.ObjectName; import net.snowflake.client.core.SFSessionProperty; +import net.snowflake.client.jdbc.internal.apache.http.client.utils.URIBuilder; import net.snowflake.client.jdbc.internal.apache.http.impl.client.CloseableHttpClient; import net.snowflake.ingest.connection.IngestResponseException; import net.snowflake.ingest.connection.OAuthCredential; @@ -187,11 +190,29 @@ public class SnowflakeStreamingIngestClientInternal implements SnowflakeStrea throw new SFException(e, ErrorCode.KEYPAIR_CREATION_FAILURE); } } else { + URI oAuthTokenEndpoint; + try { + if (prop.getProperty(Constants.OAUTH_TOKEN_ENDPOINT) == null) { + // Set OAuth token endpoint to Snowflake OAuth by default + oAuthTokenEndpoint = + new URIBuilder() + .setScheme(accountURL.getScheme()) + .setHost(accountURL.getUrlWithoutPort()) + .setPort(accountURL.getPort()) + .setPath(Constants.SNOWFLAKE_OAUTH_TOKEN_ENDPOINT) + .build(); + } else { + oAuthTokenEndpoint = new URI(prop.getProperty(Constants.OAUTH_TOKEN_ENDPOINT)); + } + } catch (URISyntaxException e) { + throw new SFException(e, ErrorCode.INVALID_URL); + } credential = new OAuthCredential( prop.getProperty(Constants.OAUTH_CLIENT_ID), prop.getProperty(Constants.OAUTH_CLIENT_SECRET), - prop.getProperty(Constants.OAUTH_REFRESH_TOKEN)); + prop.getProperty(Constants.OAUTH_REFRESH_TOKEN), + oAuthTokenEndpoint); } this.requestBuilder = new RequestBuilder( diff --git a/src/main/java/net/snowflake/ingest/utils/Constants.java b/src/main/java/net/snowflake/ingest/utils/Constants.java index 134328668..096bb4ed4 100644 --- a/src/main/java/net/snowflake/ingest/utils/Constants.java +++ b/src/main/java/net/snowflake/ingest/utils/Constants.java @@ -31,6 +31,8 @@ public class Constants { public static final String OAUTH_CLIENT_ID = "oauth_client_id"; public static final String OAUTH_CLIENT_SECRET = "oauth_client_secret"; public static final String OAUTH_REFRESH_TOKEN = "oauth_refresh_token"; + public static final String OAUTH_TOKEN_ENDPOINT = "oauth_token_endpoint"; + public static final String SNOWFLAKE_OAUTH_TOKEN_ENDPOINT = "/oauth/token-request"; public static final String PRIMARY_FILE_ID_KEY = "primaryFileId"; // Don't change, should match Parquet Scanner public static final long RESPONSE_SUCCESS = 0L; // Don't change, should match server side diff --git a/src/test/java/net/snowflake/ingest/connection/MockOAuthClient.java b/src/test/java/net/snowflake/ingest/connection/MockOAuthClient.java index f1c3b2409..ee98bd04d 100644 --- a/src/test/java/net/snowflake/ingest/connection/MockOAuthClient.java +++ b/src/test/java/net/snowflake/ingest/connection/MockOAuthClient.java @@ -2,24 +2,28 @@ import java.util.UUID; import java.util.concurrent.atomic.AtomicReference; +import net.snowflake.client.jdbc.internal.apache.http.client.utils.URIBuilder; import net.snowflake.ingest.utils.ErrorCode; import net.snowflake.ingest.utils.SFException; /** Mock implementation of {@link OAuthClient}, only use for test */ -public class MockOAuthClient implements OAuthClient { +public class MockOAuthClient extends OAuthClient { private final AtomicReference oAuthCredential; private int futureRefreshFailCount = 0; public MockOAuthClient() { - OAuthCredential mockOAuthCredential = - new OAuthCredential("CLIENT_ID", "CLIENT_SECRET", "REFRESH_TOKEN"); - oAuthCredential = new AtomicReference<>(mockOAuthCredential); + super( + "ACCOUNT_NAME", + new OAuthCredential("CLIENT_ID", "CLIENT_SECRET", "REFRESH_TOKEN"), + new URIBuilder()); + oAuthCredential = + new AtomicReference<>(new OAuthCredential("CLIENT_ID", "CLIENT_SECRET", "REFRESH_TOKEN")); oAuthCredential.get().setExpiresIn(600); } @Override - public AtomicReference getoAuthCredentialRef() { + public AtomicReference getOAuthCredentialRef() { return oAuthCredential; } diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/OAuthBasicTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/OAuthBasicTest.java index 614f606a9..6351f267d 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/OAuthBasicTest.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/OAuthBasicTest.java @@ -80,7 +80,24 @@ public void missingOAuthParam() throws Exception { new SFException(ErrorCode.MISSING_CONFIG, Constants.OAUTH_REFRESH_TOKEN).getMessage()); } - /** Create client with mock credential, should fail when refreshing token */ + /** + * Create a client with mock credential using snowflake oauth, should fail when refreshing token + */ + @Test(expected = SecurityException.class) + public void testCreateSnowflakeOAuthClient() throws Exception { + Properties props = TestUtils.getProperties(Constants.BdecVersion.THREE, false); + props.remove(Constants.PRIVATE_KEY); + props.put(Constants.AUTHORIZATION_TYPE, Constants.OAUTH); + props.put(Constants.OAUTH_CLIENT_ID, "MOCK_CLIENT_ID"); + props.put(Constants.OAUTH_CLIENT_SECRET, "MOCK_CLIENT_SECRET"); + props.put(Constants.OAUTH_REFRESH_TOKEN, "MOCK_REFRESH_TOKEN"); + SnowflakeStreamingIngestClient client = + SnowflakeStreamingIngestClientFactory.builder("MY_CLIENT").setProperties(props).build(); + } + + /** + * Create a client with mock credential using external oauth, should fail when refreshing token + */ @Test(expected = SecurityException.class) public void testCreateOAuthClient() throws Exception { Properties props = TestUtils.getProperties(Constants.BdecVersion.THREE, false); @@ -89,6 +106,7 @@ public void testCreateOAuthClient() throws Exception { props.put(Constants.OAUTH_CLIENT_ID, "MOCK_CLIENT_ID"); props.put(Constants.OAUTH_CLIENT_SECRET, "MOCK_CLIENT_SECRET"); props.put(Constants.OAUTH_REFRESH_TOKEN, "MOCK_REFRESH_TOKEN"); + props.put(Constants.OAUTH_TOKEN_ENDPOINT, "https://mockexternaloauthendpoint.test/token"); SnowflakeStreamingIngestClient client = SnowflakeStreamingIngestClientFactory.builder("MY_CLIENT").setProperties(props).build(); }