diff --git a/src/main/java/net/snowflake/client/jdbc/SnowflakeFileTransferAgent.java b/src/main/java/net/snowflake/client/jdbc/SnowflakeFileTransferAgent.java index bd5a3945e..9f9141989 100644 --- a/src/main/java/net/snowflake/client/jdbc/SnowflakeFileTransferAgent.java +++ b/src/main/java/net/snowflake/client/jdbc/SnowflakeFileTransferAgent.java @@ -1662,14 +1662,34 @@ public InputStream downloadStream(String fileName) throws SnowflakeSQLException remoteLocation remoteLocation = extractLocationAndPath(stageInfo.getLocation()); - String stageFilePath = fileName; + // when downloading files as stream there should be only one file in source files + String sourceLocation = + sourceFiles.stream() + .findFirst() + .orElseThrow( + () -> + new SnowflakeSQLException( + queryID, + SqlState.NO_DATA, + ErrorCode.FILE_NOT_FOUND.getMessageCode(), + session, + "File not found: " + fileName)); + + String stageFilePath; + if (fileName.equals(sourceLocation)) { + stageFilePath = fileName; + } else { + // filename may be different from source location e.g. in git repositories + logger.debug("Changing file to download location from {} to {}", fileName, sourceLocation); + stageFilePath = sourceLocation; + } if (!remoteLocation.path.isEmpty()) { - stageFilePath = SnowflakeUtil.concatFilePathNames(remoteLocation.path, fileName, "/"); + stageFilePath = SnowflakeUtil.concatFilePathNames(remoteLocation.path, sourceLocation, "/"); } - RemoteStoreFileEncryptionMaterial encMat = srcFileToEncMat.get(fileName); - String presignedUrl = srcFileToPresignedUrl.get(fileName); + RemoteStoreFileEncryptionMaterial encMat = srcFileToEncMat.get(sourceLocation); + String presignedUrl = srcFileToPresignedUrl.get(sourceLocation); return storageFactory .createClient(stageInfo, parallel, encMat, session) diff --git a/src/test/java/net/snowflake/client/jdbc/ConnectionLatestIT.java b/src/test/java/net/snowflake/client/jdbc/ConnectionLatestIT.java index efed33896..0b72e8364 100644 --- a/src/test/java/net/snowflake/client/jdbc/ConnectionLatestIT.java +++ b/src/test/java/net/snowflake/client/jdbc/ConnectionLatestIT.java @@ -1313,7 +1313,7 @@ public void testDownloadStreamWithFileNotFoundException() throws SQLException { .unwrap(SnowflakeConnection.class) .downloadStream("@testDownloadStream_stage", "/fileNotExist.gz", true); } catch (SQLException ex) { - assertThat(ex.getErrorCode(), is(ErrorCode.S3_OPERATION_ERROR.getMessageCode())); + assertThat(ex.getErrorCode(), is(ErrorCode.FILE_NOT_FOUND.getMessageCode())); } long endDownloadTime = System.currentTimeMillis(); // S3Client retries some exception for a default timeout of 5 minutes diff --git a/src/test/java/net/snowflake/client/jdbc/GitRepositoryDownloadLatestIT.java b/src/test/java/net/snowflake/client/jdbc/GitRepositoryDownloadLatestIT.java new file mode 100644 index 000000000..b4780fb95 --- /dev/null +++ b/src/test/java/net/snowflake/client/jdbc/GitRepositoryDownloadLatestIT.java @@ -0,0 +1,81 @@ +package net.snowflake.client.jdbc; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.List; +import net.snowflake.client.category.TestCategoryOthers; +import org.apache.commons.io.IOUtils; +import org.junit.Test; +import org.junit.experimental.categories.Category; + +@Category(TestCategoryOthers.class) +public class GitRepositoryDownloadLatestIT extends BaseJDBCWithSharedConnectionIT { + @Test + public void shouldDownloadFileAndStreamFromGitRepository() throws Exception { + prepareJdbcRepoInSnowflake(); + + String stageName = + String.format("@%s.%s.JDBC", connection.getCatalog(), connection.getSchema()); + String fileName = ".pre-commit-config.yaml"; + String filePathInGitRepo = "branches/master/" + fileName; + + List fetchedFileContent = getContentFromFile(stageName, filePathInGitRepo, fileName); + + List fetchedStreamContent = getContentFromStream(stageName, filePathInGitRepo); + + assertFalse("File content cannot be empty", fetchedFileContent.isEmpty()); + assertFalse("Stream content cannot be empty", fetchedStreamContent.isEmpty()); + assertEquals(fetchedFileContent, fetchedStreamContent); + } + + private static void prepareJdbcRepoInSnowflake() throws SQLException { + try (Statement statement = connection.createStatement()) { + statement.execute( + "CREATE OR REPLACE API INTEGRATION gh_integration\n" + + " API_PROVIDER = git_https_api\n" + + " API_ALLOWED_PREFIXES = ('https://github.com/snowflakedb/snowflake-jdbc.git')\n" + + " ENABLED = TRUE;"); + statement.execute( + "CREATE OR REPLACE GIT REPOSITORY jdbc\n" + + "ORIGIN = 'https://github.com/snowflakedb/snowflake-jdbc.git'\n" + + "API_INTEGRATION = gh_integration;"); + } + } + + private static List getContentFromFile( + String stageName, String filePathInGitRepo, String fileName) + throws IOException, SQLException { + Path tempDir = Files.createTempDirectory("git"); + String stagePath = stageName + "/" + filePathInGitRepo; + Path downloadedFile = tempDir.resolve(fileName); + String command = String.format("GET '%s' '%s'", stagePath, tempDir.toUri()); + + try (Statement statement = connection.createStatement(); + ResultSet rs = statement.executeQuery(command); ) { + // then + assertTrue("has result", rs.next()); + return Files.readAllLines(downloadedFile); + } finally { + Files.delete(downloadedFile); + Files.delete(tempDir); + } + } + + private static List getContentFromStream(String stageName, String filePathInGitRepo) + throws SQLException, IOException { + SnowflakeConnection unwrap = connection.unwrap(SnowflakeConnection.class); + try (InputStream inputStream = unwrap.downloadStream(stageName, filePathInGitRepo, false)) { + return IOUtils.readLines(inputStream, StandardCharsets.UTF_8); + } + } +}