diff --git a/src/main/java/net/snowflake/client/jdbc/SnowflakeFileTransferAgent.java b/src/main/java/net/snowflake/client/jdbc/SnowflakeFileTransferAgent.java index bd5a3945e..895275cef 100644 --- a/src/main/java/net/snowflake/client/jdbc/SnowflakeFileTransferAgent.java +++ b/src/main/java/net/snowflake/client/jdbc/SnowflakeFileTransferAgent.java @@ -1649,6 +1649,7 @@ private void uploadStream() throws SnowflakeSQLException { /** Download a file from remote, and return an input stream */ @Override public InputStream downloadStream(String fileName) throws SnowflakeSQLException { + logger.debug("Downloading file as stream: {}", fileName); if (stageInfo.getStageType() == StageInfo.StageType.LOCAL_FS) { logger.error("downloadStream function doesn't support local file system", false); @@ -1662,14 +1663,32 @@ public InputStream downloadStream(String fileName) throws SnowflakeSQLException remoteLocation remoteLocation = extractLocationAndPath(stageInfo.getLocation()); - String stageFilePath = fileName; + // when downloading files as stream there should be only one file in source files + String sourceLocation = + sourceFiles.stream() + .findFirst() + .orElseThrow( + () -> + new SnowflakeSQLException( + queryID, + SqlState.NO_DATA, + ErrorCode.FILE_NOT_FOUND.getMessageCode(), + session, + "File not found: " + fileName)); + + if (!fileName.equals(sourceLocation)) { + // filename may be different from source location e.g. in git repositories + logger.debug("Changing file to download location from {} to {}", fileName, sourceLocation); + } + String stageFilePath = sourceLocation; if (!remoteLocation.path.isEmpty()) { - stageFilePath = SnowflakeUtil.concatFilePathNames(remoteLocation.path, fileName, "/"); + stageFilePath = SnowflakeUtil.concatFilePathNames(remoteLocation.path, sourceLocation, "/"); } + logger.debug("Stage file path for {} is {}", sourceLocation, stageFilePath); - RemoteStoreFileEncryptionMaterial encMat = srcFileToEncMat.get(fileName); - String presignedUrl = srcFileToPresignedUrl.get(fileName); + RemoteStoreFileEncryptionMaterial encMat = srcFileToEncMat.get(sourceLocation); + String presignedUrl = srcFileToPresignedUrl.get(sourceLocation); return storageFactory .createClient(stageInfo, parallel, encMat, session) diff --git a/src/test/java/net/snowflake/client/jdbc/ConnectionLatestIT.java b/src/test/java/net/snowflake/client/jdbc/ConnectionLatestIT.java index 4d2129a53..99ba7abdc 100644 --- a/src/test/java/net/snowflake/client/jdbc/ConnectionLatestIT.java +++ b/src/test/java/net/snowflake/client/jdbc/ConnectionLatestIT.java @@ -1316,7 +1316,7 @@ public void testDownloadStreamWithFileNotFoundException() throws SQLException { .unwrap(SnowflakeConnection.class) .downloadStream("@testDownloadStream_stage", "/fileNotExist.gz", true); } catch (SQLException ex) { - assertThat(ex.getErrorCode(), is(ErrorCode.S3_OPERATION_ERROR.getMessageCode())); + assertThat(ex.getErrorCode(), is(ErrorCode.FILE_NOT_FOUND.getMessageCode())); } long endDownloadTime = System.currentTimeMillis(); // S3Client retries some exception for a default timeout of 5 minutes diff --git a/src/test/java/net/snowflake/client/jdbc/GitRepositoryDownloadLatestIT.java b/src/test/java/net/snowflake/client/jdbc/GitRepositoryDownloadLatestIT.java new file mode 100644 index 000000000..b720591de --- /dev/null +++ b/src/test/java/net/snowflake/client/jdbc/GitRepositoryDownloadLatestIT.java @@ -0,0 +1,99 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ +package net.snowflake.client.jdbc; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.List; +import net.snowflake.client.ConditionalIgnoreRule; +import net.snowflake.client.RunningOnGithubAction; +import net.snowflake.client.category.TestCategoryOthers; +import org.apache.commons.io.IOUtils; +import org.junit.Test; +import org.junit.experimental.categories.Category; + +@Category(TestCategoryOthers.class) +public class GitRepositoryDownloadLatestIT extends BaseJDBCTest { + + /** + * Test needs to set up git integration which is not available in GH Action tests and needs + * accountadmin role. Added in > 3.19.0 + */ + @Test + @ConditionalIgnoreRule.ConditionalIgnore(condition = RunningOnGithubAction.class) + public void shouldDownloadFileAndStreamFromGitRepository() throws Exception { + try (Connection connection = getConnection()) { + prepareJdbcRepoInSnowflake(connection); + + String stageName = + String.format("@%s.%s.JDBC", connection.getCatalog(), connection.getSchema()); + String fileName = ".pre-commit-config.yaml"; + String filePathInGitRepo = "branches/master/" + fileName; + + List fetchedFileContent = + getContentFromFile(connection, stageName, filePathInGitRepo, fileName); + + List fetchedStreamContent = + getContentFromStream(connection, stageName, filePathInGitRepo); + + assertFalse("File content cannot be empty", fetchedFileContent.isEmpty()); + assertFalse("Stream content cannot be empty", fetchedStreamContent.isEmpty()); + assertEquals(fetchedFileContent, fetchedStreamContent); + } + } + + private static void prepareJdbcRepoInSnowflake(Connection connection) throws SQLException { + try (Statement statement = connection.createStatement()) { + statement.execute("use role accountadmin"); + statement.execute( + "CREATE OR REPLACE API INTEGRATION gh_integration\n" + + " API_PROVIDER = git_https_api\n" + + " API_ALLOWED_PREFIXES = ('https://github.com/snowflakedb/snowflake-jdbc.git')\n" + + " ENABLED = TRUE;"); + statement.execute( + "CREATE OR REPLACE GIT REPOSITORY jdbc\n" + + "ORIGIN = 'https://github.com/snowflakedb/snowflake-jdbc.git'\n" + + "API_INTEGRATION = gh_integration;"); + } + } + + private static List getContentFromFile( + Connection connection, String stageName, String filePathInGitRepo, String fileName) + throws IOException, SQLException { + Path tempDir = Files.createTempDirectory("git"); + String stagePath = stageName + "/" + filePathInGitRepo; + Path downloadedFile = tempDir.resolve(fileName); + String command = String.format("GET '%s' '%s'", stagePath, tempDir.toUri()); + + try (Statement statement = connection.createStatement(); + ResultSet rs = statement.executeQuery(command); ) { + // then + assertTrue("has result", rs.next()); + return Files.readAllLines(downloadedFile); + } finally { + Files.delete(downloadedFile); + Files.delete(tempDir); + } + } + + private static List getContentFromStream( + Connection connection, String stageName, String filePathInGitRepo) + throws SQLException, IOException { + SnowflakeConnection unwrap = connection.unwrap(SnowflakeConnection.class); + try (InputStream inputStream = unwrap.downloadStream(stageName, filePathInGitRepo, false)) { + return IOUtils.readLines(inputStream, StandardCharsets.UTF_8); + } + } +} diff --git a/src/test/java/net/snowflake/client/jdbc/StreamLatestIT.java b/src/test/java/net/snowflake/client/jdbc/StreamLatestIT.java index 3ab179b70..093c2de27 100644 --- a/src/test/java/net/snowflake/client/jdbc/StreamLatestIT.java +++ b/src/test/java/net/snowflake/client/jdbc/StreamLatestIT.java @@ -119,7 +119,7 @@ public void testDownloadToStreamBlobNotFoundGCS() throws SQLException { assertTrue(ex instanceof SQLException); assertTrue( "Wrong exception message: " + ex.getMessage(), - ex.getMessage().matches(".*Blob.*not found in bucket.*")); + ex.getMessage().contains("File not found")); } finally { statement.execute("rm @~/" + DEST_PREFIX); }