Skip to content

Commit

Permalink
SNOW-1196041: Parameterize disabling default credentials for GCS client
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-dprzybysz committed Mar 6, 2024
1 parent 2bd5370 commit b748574
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 45 deletions.
11 changes: 11 additions & 0 deletions src/main/java/net/snowflake/client/core/SFBaseSession.java
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,9 @@ public abstract class SFBaseSession {
// we need to allow for it to maintain backwards compatibility.
private boolean enablePatternSearch = true;

/** Disable lookup for default credentials by GCS library */
private boolean disableGcsDefaultCredentials = false;

private Map<String, Object> commonParameters;

protected SFBaseSession(SFConnectionHandler sfConnectionHandler) {
Expand Down Expand Up @@ -726,6 +729,14 @@ public void setEnablePatternSearch(boolean enablePatternSearch) {
this.enablePatternSearch = enablePatternSearch;
}

public boolean getDisableGcsDefaultCredentials() {
return disableGcsDefaultCredentials;
}

public void setDisableGcsDefaultCredentials(boolean disableGcsDefaultCredentials) {
this.disableGcsDefaultCredentials = disableGcsDefaultCredentials;
}

public int getClientResultChunkSize() {
return clientResultChunkSize;
}
Expand Down
5 changes: 5 additions & 0 deletions src/main/java/net/snowflake/client/core/SFSession.java
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,11 @@ public void addSFSessionProperty(String propertyName, Object propertyValue) thro
setEnablePatternSearch(getBooleanValue(propertyValue));
}
break;
case DISABLE_GCS_DEFAULT_CREDENTIALS:
if (propertyValue != null) {
setDisableGcsDefaultCredentials(getBooleanValue(propertyValue));
}
break;

default:
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,9 @@ public enum SFSessionProperty {

RETRY_TIMEOUT("retryTimeout", false, Integer.class),

ENABLE_PATTERN_SEARCH("enablePatternSearch", false, Boolean.class);
ENABLE_PATTERN_SEARCH("enablePatternSearch", false, Boolean.class),

DISABLE_GCS_DEFAULT_CREDENTIALS("disableGcsDefaultCredentials", false, Boolean.class);

// property key in string
private String propertyKey;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.api.gax.paging.Page;
import com.google.api.gax.rpc.FixedHeaderProvider;
import com.google.auth.oauth2.AccessToken;
import com.google.auth.oauth2.GoogleCredentials;
import com.google.cloud.storage.Blob;
Expand Down Expand Up @@ -1201,14 +1202,21 @@ private void setupGCSClient(
try {
String accessToken = (String) stage.getCredentials().get("GCS_ACCESS_TOKEN");
if (accessToken != null) {
// We are authenticated with an oauth access token.
StorageOptions.Builder builder = StorageOptions.newBuilder();
if (session.getDisableGcsDefaultCredentials()) {
logger.debug(
"Adding explicit credentials to avoid default credential lookup by the GCS client");
builder.setCredentials(GoogleCredentials.create(new AccessToken(accessToken, null)));
}

// Using GoogleCredential with access token will cause IllegalStateException when the token
// is expired and trying to refresh, which cause error cannot be caught. Instead, set a
// header so we can caught the error code.

// We are authenticated with an oauth access token.
this.gcsClient =
StorageOptions.newBuilder()
.setCredentials(GoogleCredentials.create(new AccessToken(accessToken, null)))
builder
.setHeaderProvider(
FixedHeaderProvider.create("Authorization", "Bearer " + accessToken))
.build()
.getService();
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
*/
public class StorageClientFactory {

private static final SFLogger logger = SFLoggerFactory.getLogger(SnowflakeS3Client.class);
private static final SFLogger logger = SFLoggerFactory.getLogger(StorageClientFactory.class);

private static StorageClientFactory factory;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import net.snowflake.client.core.Constants;
import net.snowflake.client.core.OCSPMode;
import net.snowflake.client.core.SFSession;
import net.snowflake.client.core.SFSessionProperty;
import net.snowflake.client.core.SFStatement;
import net.snowflake.client.jdbc.cloud.storage.SnowflakeStorageClient;
import net.snowflake.client.jdbc.cloud.storage.StageInfo;
Expand Down Expand Up @@ -1105,58 +1106,66 @@ private void testGeometryMetadataSingle(
@Test
@ConditionalIgnoreRule.ConditionalIgnore(condition = RunningOnGithubAction.class)
public void testPutGetGcsDownscopedCredential() throws Throwable {
Connection connection = null;
Statement statement = null;
Properties paramProperties = new Properties();
paramProperties.put("GCS_USE_DOWNSCOPED_CREDENTIAL", true);
try {
connection = getConnection("gcpaccount", paramProperties);
try (Connection connection = getConnection("gcpaccount", paramProperties);
Statement statement = connection.createStatement()) {
putAndGetFile(statement);
}
}

statement = connection.createStatement();
@Test
@ConditionalIgnoreRule.ConditionalIgnore(condition = RunningOnGithubAction.class)
public void testPutGetGcsDownscopedCredentialWithDisabledDefaultCredentials() throws Throwable {
Properties paramProperties = new Properties();
paramProperties.put("GCS_USE_DOWNSCOPED_CREDENTIAL", true);
paramProperties.put(SFSessionProperty.DISABLE_GCS_DEFAULT_CREDENTIALS.getPropertyKey(), true);
try (Connection connection = getConnection("gcpaccount", paramProperties);
Statement statement = connection.createStatement()) {
putAndGetFile(statement);
}
}

String sourceFilePath = getFullPathFileInResource(TEST_DATA_FILE_2);
private void putAndGetFile(Statement statement) throws Throwable {
String sourceFilePath = getFullPathFileInResource(TEST_DATA_FILE_2);

File destFolder = tmpFolder.newFolder();
String destFolderCanonicalPath = destFolder.getCanonicalPath();
String destFolderCanonicalPathWithSeparator = destFolderCanonicalPath + File.separator;
File destFolder = tmpFolder.newFolder();
String destFolderCanonicalPath = destFolder.getCanonicalPath();
String destFolderCanonicalPathWithSeparator = destFolderCanonicalPath + File.separator;

try {
statement.execute("CREATE OR REPLACE STAGE testPutGet_stage");
try {
statement.execute("CREATE OR REPLACE STAGE testPutGet_stage");

assertTrue(
"Failed to put a file",
statement.execute("PUT file://" + sourceFilePath + " @testPutGet_stage"));
assertTrue(
"Failed to put a file",
statement.execute("PUT file://" + sourceFilePath + " @testPutGet_stage"));

findFile(statement, "ls @testPutGet_stage/");
findFile(statement, "ls @testPutGet_stage/");

// download the file we just uploaded to stage
assertTrue(
"Failed to get a file",
statement.execute(
"GET @testPutGet_stage 'file://" + destFolderCanonicalPath + "' parallel=8"));
// download the file we just uploaded to stage
assertTrue(
"Failed to get a file",
statement.execute(
"GET @testPutGet_stage 'file://" + destFolderCanonicalPath + "' parallel=8"));

// Make sure that the downloaded file exists, it should be gzip compressed
File downloaded = new File(destFolderCanonicalPathWithSeparator + TEST_DATA_FILE_2 + ".gz");
assert (downloaded.exists());
// Make sure that the downloaded file exists, it should be gzip compressed
File downloaded = new File(destFolderCanonicalPathWithSeparator + TEST_DATA_FILE_2 + ".gz");
assert (downloaded.exists());

Process p =
Runtime.getRuntime()
.exec("gzip -d " + destFolderCanonicalPathWithSeparator + TEST_DATA_FILE_2 + ".gz");
p.waitFor();
Process p =
Runtime.getRuntime()
.exec("gzip -d " + destFolderCanonicalPathWithSeparator + TEST_DATA_FILE_2 + ".gz");
p.waitFor();

File original = new File(sourceFilePath);
File unzipped = new File(destFolderCanonicalPathWithSeparator + TEST_DATA_FILE_2);
System.out.println(
"Original file: " + original.getAbsolutePath() + ", size: " + original.length());
System.out.println(
"Unzipped file: " + unzipped.getAbsolutePath() + ", size: " + unzipped.length());
assert (original.length() == unzipped.length());
} finally {
statement.execute("DROP STAGE IF EXISTS testGetPut_stage");
statement.close();
}
File original = new File(sourceFilePath);
File unzipped = new File(destFolderCanonicalPathWithSeparator + TEST_DATA_FILE_2);
System.out.println(
"Original file: " + original.getAbsolutePath() + ", size: " + original.length());
System.out.println(
"Unzipped file: " + unzipped.getAbsolutePath() + ", size: " + unzipped.length());
assert (original.length() == unzipped.length());
} finally {
closeSQLObjects(null, statement, connection);
statement.execute("DROP STAGE IF EXISTS testGetPut_stage");
}
}

Expand Down

0 comments on commit b748574

Please sign in to comment.