From 78ddc27caf1fa7f6b1bfd29b1bab5ca70755b620 Mon Sep 17 00:00:00 2001 From: Dariusz Stempniak Date: Wed, 13 Sep 2023 09:53:11 +0200 Subject: [PATCH] SNOW-893835 Arrow - support for multi-chunk results (#771) ### Description Arrow - support for multi-chunk results. ### Checklist - [x] Code compiles correctly - [x] Code is formatted according to [Coding Conventions](../CodingConventions.md) - [x] Created tests which fail without the change (if possible) - [x] All tests passing (`dotnet test`) - [x] Extended the README / documentation, if necessary - [x] Provide JIRA issue id (if possible) or GitHub issue id in PR name --- .../IntegrationTests/SFDbCommandIT.cs | 43 ++- .../IntegrationTests/SFReusableChunkTest.cs | 2 +- .../UnitTests/ArrowChunkParserTest.cs | 64 ++++ .../UnitTests/ArrowResultChunkTest.cs | 140 ++++++++- .../UnitTests/ArrowResultSetTest.cs | 17 +- .../UnitTests/ChunkDeserializerTest.cs | 52 ++-- .../UnitTests/ChunkDownloaderFactoryTest.cs | 7 +- .../UnitTests/ChunkParserFactoryTest.cs | 4 +- .../UnitTests/ChunkStreamingParserTest.cs | 55 ++-- .../UnitTests/SFReusableChunkTest.cs | 288 +++++++++--------- Snowflake.Data/Core/ArrowChunkParser.cs | 37 +++ Snowflake.Data/Core/ArrowResultChunk.cs | 101 ++++-- Snowflake.Data/Core/ArrowResultSet.cs | 52 ++-- Snowflake.Data/Core/BaseResultChunk.cs | 44 +++ Snowflake.Data/Core/ChunkDeserializer.cs | 2 +- Snowflake.Data/Core/ChunkDownloaderFactory.cs | 3 +- Snowflake.Data/Core/ChunkParserFactory.cs | 5 +- Snowflake.Data/Core/ChunkStreamingParser.cs | 61 ++-- Snowflake.Data/Core/IChunkDownloader.cs | 2 +- Snowflake.Data/Core/IChunkParserFactory.cs | 2 +- Snowflake.Data/Core/IResultChunk.cs | 42 +-- Snowflake.Data/Core/ReusableChunkParser.cs | 6 +- Snowflake.Data/Core/SFBaseResultSet.cs | 9 + .../Core/SFBlockingChunkDownloader.cs | 25 +- .../Core/SFBlockingChunkDownloaderV3.cs | 66 ++-- Snowflake.Data/Core/SFChunkDownloaderV2.cs | 31 +- Snowflake.Data/Core/SFResultChunk.cs | 65 ++-- Snowflake.Data/Core/SFResultSet.cs | 109 ++----- Snowflake.Data/Core/SFReusableChunk.cs | 56 ++-- 29 files changed, 844 insertions(+), 546 deletions(-) create mode 100644 Snowflake.Data.Tests/UnitTests/ArrowChunkParserTest.cs create mode 100755 Snowflake.Data/Core/ArrowChunkParser.cs create mode 100755 Snowflake.Data/Core/BaseResultChunk.cs diff --git a/Snowflake.Data.Tests/IntegrationTests/SFDbCommandIT.cs b/Snowflake.Data.Tests/IntegrationTests/SFDbCommandIT.cs index 4994ead35..d11eef143 100755 --- a/Snowflake.Data.Tests/IntegrationTests/SFDbCommandIT.cs +++ b/Snowflake.Data.Tests/IntegrationTests/SFDbCommandIT.cs @@ -237,12 +237,15 @@ public void TestSimpleLargeResultSet() IDbCommand cmd = conn.CreateCommand(); cmd.CommandText = "select seq4(), uniform(1, 10, 42) from table(generator(rowcount => 1000000)) v order by 1"; - IDataReader reader = cmd.ExecuteReader(); - int counter = 0; - while (reader.Read()) + using (IDataReader reader = cmd.ExecuteReader()) { - Assert.AreEqual(counter.ToString(), reader.GetString(0)); - counter++; + int counter = 0; + while (reader.Read()) + { + Assert.AreEqual(counter.ToString(), reader.GetString(0)); + // don't test the second column as it has random values just to increase the response size + counter++; + } } conn.Close(); } @@ -273,6 +276,7 @@ public void TestUseV1ResultParser() while (reader.Read()) { Assert.AreEqual(counter.ToString(), reader.GetString(0)); + // don't test the second column as it has random values just to increase the response size counter++; } conn.Close(); @@ -302,6 +306,7 @@ public void 
TestUseV2ChunkDownloader() while (reader.Read()) { Assert.AreEqual(counter.ToString(), reader.GetString(0)); + // don't test the second column as it has random values just to increase the response size counter++; } conn.Close(); @@ -310,6 +315,33 @@ public void TestUseV2ChunkDownloader() SFConfiguration.Instance().ChunkDownloaderVersion = chunkDownloaderVersion; } + [Test] + [Parallelizable(ParallelScope.Children)] + public void TestDefaultChunkDownloaderWithPrefetchThreads([Values(1, 2, 4)] int prefetchThreads) + { + using (SnowflakeDbConnection conn = new SnowflakeDbConnection(ConnectionString)) + { + conn.Open(); + + IDbCommand cmd = conn.CreateCommand(); + cmd.CommandText = $"alter session set CLIENT_PREFETCH_THREADS = {prefetchThreads}"; + cmd.ExecuteNonQuery(); + + // 200000 - empirical value to return 3 additional chunks for both JSON and Arrow response + cmd.CommandText = "select seq4(), uniform(1, 10, 42) from table(generator(rowcount => 200000)) v order by 1"; + + IDataReader reader = cmd.ExecuteReader(); + int counter = 0; + while (reader.Read()) + { + Assert.AreEqual(counter.ToString(), reader.GetString(0)); + // don't test the second column as it has random values just to increase the response size + counter++; + } + conn.Close(); + } + } + [Test] public void TestDataSourceError() { @@ -517,7 +549,6 @@ public void TestRowsAffected() Assert.AreEqual(expectedResult[i], rowsAffected); } } - conn.Close(); } } diff --git a/Snowflake.Data.Tests/IntegrationTests/SFReusableChunkTest.cs b/Snowflake.Data.Tests/IntegrationTests/SFReusableChunkTest.cs index dd7ad5693..07ed0046d 100644 --- a/Snowflake.Data.Tests/IntegrationTests/SFReusableChunkTest.cs +++ b/Snowflake.Data.Tests/IntegrationTests/SFReusableChunkTest.cs @@ -230,7 +230,7 @@ public TestChunkParserFactory(int exceptionsToThrow) _exceptionsThrown = 0; } - public IChunkParser GetParser(Stream stream) + public IChunkParser GetParser(ResultFormat resultFormat, Stream stream) { if (++_exceptionsThrown <= _expectedExceptionsNumber) return new ThrowingReusableChunkParser(); diff --git a/Snowflake.Data.Tests/UnitTests/ArrowChunkParserTest.cs b/Snowflake.Data.Tests/UnitTests/ArrowChunkParserTest.cs new file mode 100644 index 000000000..ce9cbfb81 --- /dev/null +++ b/Snowflake.Data.Tests/UnitTests/ArrowChunkParserTest.cs @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2023 Snowflake Computing Inc. All rights reserved. 
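 *
 * A minimal consumer sketch of the prefetch behaviour exercised by
 * TestDefaultChunkDownloaderWithPrefetchThreads above; the connection string
 * and the thread count are placeholders:
 *
 *   using (var conn = new SnowflakeDbConnection("<connection string>"))
 *   {
 *       conn.Open();
 *       IDbCommand cmd = conn.CreateCommand();
 *       cmd.CommandText = "alter session set CLIENT_PREFETCH_THREADS = 4";
 *       cmd.ExecuteNonQuery();
 *       cmd.CommandText = "select seq4() from table(generator(rowcount => 200000)) order by 1";
 *       using (IDataReader reader = cmd.ExecuteReader())
 *       {
 *           while (reader.Read())
 *           {
 *               reader.GetString(0); // each Read() may transparently cross a chunk boundary
 *           }
 *       }
 *   }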
+ */ + +using System.Linq; +using Apache.Arrow; +using Apache.Arrow.Ipc; + +namespace Snowflake.Data.Tests.UnitTests +{ + using NUnit.Framework; + using Snowflake.Data.Client; + using Snowflake.Data.Configuration; + using Snowflake.Data.Core; + using System; + using System.IO; + using System.Text; + using System.Threading.Tasks; + + [TestFixture, NonParallelizable] + class ArrowChunkParserTest + { + [Test] + [Ignore("ArrowChunkParserTest")] + public void ArrowChunkParserTestDone() + { + // Do nothing - test progress marker + } + + [Test] + public void TestParseChunkReadsRecordBatches([Values(1, 2, 4)] int numberOfRecordBatch) + { + // Arrange + MemoryStream stream = new MemoryStream(); + + for (var i = 0; i < numberOfRecordBatch; i++) + { + var numberOfRecordsInBatch = 10 * i; + var recordBatch = new RecordBatch.Builder() + .Append("Col_Int32", false, col => col.Int32(array => array.AppendRange(Enumerable.Range(1, numberOfRecordsInBatch)))) + .Build(); + + ArrowStreamWriter writer = new ArrowStreamWriter(stream, recordBatch.Schema, true); + writer.WriteRecordBatch(recordBatch); + } + stream.Position = 0; + + var parser = new ArrowChunkParser(stream); + + // Act + var chunk = new ArrowResultChunk(1); + var task = parser.ParseChunk(chunk); + task.Wait(); + + // Assert + Assert.AreEqual(numberOfRecordBatch, chunk.RecordBatch.Count); + for (var i = 0; i < numberOfRecordBatch; i++) + { + var numberOfRecordsInBatch = 10 * i; + Assert.AreEqual(numberOfRecordsInBatch, chunk.RecordBatch[i].Length); + } + } + } +} diff --git a/Snowflake.Data.Tests/UnitTests/ArrowResultChunkTest.cs b/Snowflake.Data.Tests/UnitTests/ArrowResultChunkTest.cs index 3d41e1338..642afb275 100755 --- a/Snowflake.Data.Tests/UnitTests/ArrowResultChunkTest.cs +++ b/Snowflake.Data.Tests/UnitTests/ArrowResultChunkTest.cs @@ -13,10 +13,19 @@ namespace Snowflake.Data.Tests.UnitTests [TestFixture] class ArrowResultChunkTest { - private const int RowCount = 10; - private RecordBatch _recordBatch; + private const int RowCountBatchOne = 10; + private const int RowCountBatchTwo = 20; + private readonly RecordBatch _recordBatchOne = new RecordBatch.Builder() + .Append("Col_Int32", false, col => col.Int32( + array => array.AppendRange(Enumerable.Range(1, RowCountBatchOne)))) + .Build(); + private readonly RecordBatch _recordBatchTwo = new RecordBatch.Builder() + .Append("Col_Int32", false, col => col.Int32( + array => array.AppendRange(Enumerable.Range(1, RowCountBatchTwo)))) + .Build(); private ArrowResultChunk _chunk; + [Test] [Ignore("ArrowResultChunkTest")] public void SFArrowResultChunkTestDone() @@ -24,41 +33,142 @@ public void SFArrowResultChunkTestDone() // Do nothing - test progress marker } - [SetUp] - public void BeforeTest() + [Test] + public void TestAddRecordBatchAddsBatchTwo() + { + _chunk = new ArrowResultChunk(_recordBatchOne); + _chunk.AddRecordBatch(_recordBatchTwo); + + Assert.AreEqual(2, _chunk.RecordBatch.Count); + } + + [Test] + public void TestNextIteratesThroughAllRecordsOfOneBatch() { - _recordBatch = new RecordBatch.Builder() - .Append("Col_Int32", false, col => col.Int32(array => array.AppendRange(Enumerable.Range(1, RowCount)))) - .Build(); - _chunk = new ArrowResultChunk(_recordBatch); + _chunk = new ArrowResultChunk(_recordBatchOne); + + for (var i = 0; i < RowCountBatchOne; ++i) + { + Assert.IsTrue(_chunk.Next()); + } + Assert.IsFalse(_chunk.Next()); + } + + [Test] + public void TestNextIteratesThroughAllRecordsOfTwoBatches() + { + _chunk = new ArrowResultChunk(_recordBatchOne); + 
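// A second batch is appended below; Next() must then walk all
// RowCountBatchOne + RowCountBatchTwo rows as one continuous sequence.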
_chunk.AddRecordBatch(_recordBatchTwo); + + for (var i = 0; i < RowCountBatchOne + RowCountBatchTwo; ++i) + { + Assert.IsTrue(_chunk.Next()); + } + Assert.IsFalse(_chunk.Next()); + } + + [Test] + public void TestRewindIteratesThroughAllRecordsOfBatchOne() + { + _chunk = new ArrowResultChunk(_recordBatchOne); + + // move to the end of the batch + while (_chunk.Next()) {} + + for (var i = 0; i < RowCountBatchOne; ++i) + { + Assert.IsTrue(_chunk.Rewind()); + } + Assert.IsFalse(_chunk.Rewind()); } + [Test] + public void TestRewindIteratesThroughAllRecordsOfTwoBatches() + { + _chunk = new ArrowResultChunk(_recordBatchOne); + _chunk.AddRecordBatch(_recordBatchTwo); + + // move to the end of the batch + while (_chunk.Next()) {} + + for (var i = 0; i < RowCountBatchOne + RowCountBatchTwo; ++i) + { + Assert.IsTrue(_chunk.Rewind()); + } + Assert.IsFalse(_chunk.Rewind()); + } + + [Test] + public void TestResetClearsChunkData() + { + ExecResponseChunk chunkInfo = new ExecResponseChunk() + { + url = "new_url", + uncompressedSize = 100, + rowCount = 2 + }; + _chunk = new ArrowResultChunk(_recordBatchOne); + + _chunk.Reset(chunkInfo, 0); + + Assert.AreEqual(0, _chunk.ChunkIndex); + Assert.AreEqual(chunkInfo.url, _chunk.Url); + Assert.AreEqual(chunkInfo.rowCount, _chunk.RowCount); + } + + [Test] + public void TestExtractCellWithRowParameterReadsAllRows() + { + _chunk = new ArrowResultChunk(_recordBatchOne); + + var column = (Int32Array)_recordBatchOne.Column(0); + for (var i = 0; i < RowCountBatchOne; ++i) + { + var valueFromRecordBatch = column.GetValue(i).ToString(); + Assert.AreEqual(valueFromRecordBatch, _chunk.ExtractCell(i, 0).SafeToString()); + } + } + [Test] public void TestExtractCellReadsAllRows() { - var column = (Int32Array)_recordBatch.Column(0); - for (var i = 0; i < RowCount; ++i) + _chunk = new ArrowResultChunk(_recordBatchOne); + + var column = (Int32Array)_recordBatchOne.Column(0); + for (var i = 0; i < RowCountBatchOne; ++i) { - Assert.AreEqual(column.GetValue(i).ToString(), _chunk.ExtractCell(i, 0).SafeToString()); + var valueFromRecordBatch = column.GetValue(i).ToString(); + + _chunk.Next(); + Assert.AreEqual(valueFromRecordBatch, _chunk.ExtractCell(0).SafeToString()); } } [Test] public void TestExtractCellThrowsOutOfRangeException() { - Assert.Throws(() => _chunk.ExtractCell(RowCount, 0).SafeToString()); + _chunk = new ArrowResultChunk(_recordBatchOne); + + // move to the end of the batch + while (_chunk.Next()) {} + + Assert.Throws(() => _chunk.ExtractCell(0).SafeToString()); } [Test] - public void TestGetRowCountReturnsNumberOfRows() + public void TestRowCountReturnsNumberOfRows() { - Assert.AreEqual(RowCount, _chunk.GetRowCount()); + _chunk = new ArrowResultChunk(_recordBatchOne); + + Assert.AreEqual(RowCountBatchOne, _chunk.RowCount); } [Test] public void TestGetChunkIndexReturnsFirstChunk() { - Assert.AreEqual(0, _chunk.GetChunkIndex()); + _chunk = new ArrowResultChunk(_recordBatchOne); + + Assert.AreEqual(0, _chunk.ChunkIndex); } } diff --git a/Snowflake.Data.Tests/UnitTests/ArrowResultSetTest.cs b/Snowflake.Data.Tests/UnitTests/ArrowResultSetTest.cs index 46a65fdd7..aca014d4d 100755 --- a/Snowflake.Data.Tests/UnitTests/ArrowResultSetTest.cs +++ b/Snowflake.Data.Tests/UnitTests/ArrowResultSetTest.cs @@ -92,16 +92,24 @@ public void TestHasRowsReturnsFalseIfNoRows() Assert.IsFalse(_arrowResultSet.HasRows()); } + [Test] + public void TestRewindReturnsFalseBeforeFirstRow() + { + Assert.IsFalse(_arrowResultSet.Rewind()); + } + [Test] public void TestRewindReturnsFalseForFirstRow() { 
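// Rewind() mirrors Next(): on the first row there is no previous row to
// move back to, so it reports false, just as it does before any Next() call.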
+ _arrowResultSet.Next(); // move to first row Assert.IsFalse(_arrowResultSet.Rewind()); } [Test] public void TestRewindReturnsTrueForSecondRowAndMovesToFirstRow() { - _arrowResultSet.Next(); + _arrowResultSet.Next(); // move to first row + _arrowResultSet.Next(); // move to second row Assert.IsTrue(_arrowResultSet.Rewind()); Assert.IsFalse(_arrowResultSet.Rewind()); } @@ -109,8 +117,9 @@ public void TestRewindReturnsTrueForSecondRowAndMovesToFirstRow() [Test] public void TestRewindReturnsTrueForThirdRowAndMovesToFirstRow() { - _arrowResultSet.Next(); - _arrowResultSet.Next(); + _arrowResultSet.Next(); // move to first row + _arrowResultSet.Next(); // move to second row + _arrowResultSet.Next(); // move to third row Assert.IsTrue(_arrowResultSet.Rewind()); Assert.IsTrue(_arrowResultSet.Rewind()); Assert.IsFalse(_arrowResultSet.Rewind()); @@ -140,7 +149,7 @@ private QueryExecResponseData PrepareResponseData(RecordBatch recordBatch) type = "TEXT" }).ToList(), parameters = new List(), - chunks = null, // TODO in SNOW-893835 - add tests with multiple chunks + chunks = null, queryResultFormat = ResultFormat.ARROW, rowsetBase64 = ConvertToBase64String(recordBatch) }; diff --git a/Snowflake.Data.Tests/UnitTests/ChunkDeserializerTest.cs b/Snowflake.Data.Tests/UnitTests/ChunkDeserializerTest.cs index 0b036681b..63cc51f41 100644 --- a/Snowflake.Data.Tests/UnitTests/ChunkDeserializerTest.cs +++ b/Snowflake.Data.Tests/UnitTests/ChunkDeserializerTest.cs @@ -41,7 +41,7 @@ public IChunkParser getParser(string data) { byte[] bytes = Encoding.UTF8.GetBytes(data); Stream stream = new MemoryStream(bytes); - return ChunkParserFactory.Instance.GetParser(stream); + return ChunkParserFactory.Instance.GetParser(ResultFormat.JSON, stream); } [Test] @@ -54,10 +54,11 @@ public async Task TestParsingEmptyChunk() SFResultChunk chunk = new SFResultChunk(new string[1, 1]); await parser.ParseChunk(chunk); - - Assert.AreEqual(0, chunk.rowSet.GetLength(0)); // Check row length - Assert.AreEqual(0, chunk.rowSet.GetLength(1)); // Check col length - Assert.Throws(() => chunk.ExtractCell(0, 0).SafeToString()); + chunk.Next(); + + Assert.AreEqual(0, chunk.RowSet.GetLength(0)); // Check row length + Assert.AreEqual(0, chunk.RowSet.GetLength(1)); // Check col length + Assert.Throws(() => chunk.ExtractCell(0).SafeToString()); } [Test] @@ -70,10 +71,11 @@ public async Task TestParsingEmptyArraysInChunk() SFResultChunk chunk = new SFResultChunk(new string[1, 1]); await parser.ParseChunk(chunk); - - Assert.AreEqual(2, chunk.rowSet.GetLength(0)); // Check row length - Assert.AreEqual(0, chunk.rowSet.GetLength(1)); // Check col length - Assert.Throws(() => chunk.ExtractCell(0, 0).SafeToString()); + chunk.Next(); + + Assert.AreEqual(2, chunk.RowSet.GetLength(0)); // Check row length + Assert.AreEqual(0, chunk.RowSet.GetLength(1)); // Check col length + Assert.Throws(() => chunk.ExtractCell(0).SafeToString()); } [Test] @@ -112,13 +114,16 @@ public async Task TestParsingSimpleChunk() SFResultChunk chunk = new SFResultChunk(new string[1, 1]); await parser.ParseChunk(chunk); - - Assert.AreEqual("1", chunk.ExtractCell(0, 0).SafeToString()); - Assert.AreEqual("1.234", chunk.ExtractCell(0, 1).SafeToString()); - Assert.AreEqual("abcde", chunk.ExtractCell(0, 2).SafeToString()); - Assert.AreEqual("2", chunk.ExtractCell(1, 0).SafeToString()); - Assert.AreEqual("5.678", chunk.ExtractCell(1, 1).SafeToString()); - Assert.AreEqual("fghi", chunk.ExtractCell(1, 2).SafeToString()); + + chunk.Next(); + Assert.AreEqual("1", 
chunk.ExtractCell(0).SafeToString()); + Assert.AreEqual("1.234", chunk.ExtractCell(1).SafeToString()); + Assert.AreEqual("abcde", chunk.ExtractCell(2).SafeToString()); + + chunk.Next(); + Assert.AreEqual("2", chunk.ExtractCell(0).SafeToString()); + Assert.AreEqual("5.678", chunk.ExtractCell(1).SafeToString()); + Assert.AreEqual("fghi", chunk.ExtractCell(2).SafeToString()); } [Test] @@ -132,12 +137,15 @@ public async Task TestParsingChunkWithNullValue() await parser.ParseChunk(chunk); - Assert.AreEqual(null, chunk.ExtractCell(0, 0).SafeToString()); - Assert.AreEqual("1.234", chunk.ExtractCell(0, 1).SafeToString()); - Assert.AreEqual(null, chunk.ExtractCell(0, 2).SafeToString()); - Assert.AreEqual("2", chunk.ExtractCell(1, 0).SafeToString()); - Assert.AreEqual(null, chunk.ExtractCell(1, 1).SafeToString()); - Assert.AreEqual("fghi", chunk.ExtractCell(1, 2).SafeToString()); + chunk.Next(); + Assert.AreEqual(null, chunk.ExtractCell(0).SafeToString()); + Assert.AreEqual("1.234", chunk.ExtractCell(1).SafeToString()); + Assert.AreEqual(null, chunk.ExtractCell(2).SafeToString()); + + chunk.Next(); + Assert.AreEqual("2", chunk.ExtractCell(0).SafeToString()); + Assert.AreEqual(null, chunk.ExtractCell(1).SafeToString()); + Assert.AreEqual("fghi", chunk.ExtractCell(2).SafeToString()); } } } diff --git a/Snowflake.Data.Tests/UnitTests/ChunkDownloaderFactoryTest.cs b/Snowflake.Data.Tests/UnitTests/ChunkDownloaderFactoryTest.cs index bb016376b..a9d994564 100644 --- a/Snowflake.Data.Tests/UnitTests/ChunkDownloaderFactoryTest.cs +++ b/Snowflake.Data.Tests/UnitTests/ChunkDownloaderFactoryTest.cs @@ -38,7 +38,12 @@ private QueryExecResponseData mockQueryRequestData() rowSet = new string[,] { { } }, rowType = new List(), parameters = new List(), - chunks = new List() + chunks = new List{new ExecResponseChunk() + { + url = "fake", + uncompressedSize = 100, + rowCount = 1 + }} }; } diff --git a/Snowflake.Data.Tests/UnitTests/ChunkParserFactoryTest.cs b/Snowflake.Data.Tests/UnitTests/ChunkParserFactoryTest.cs index 715f7c14f..da1c1ae0e 100644 --- a/Snowflake.Data.Tests/UnitTests/ChunkParserFactoryTest.cs +++ b/Snowflake.Data.Tests/UnitTests/ChunkParserFactoryTest.cs @@ -50,12 +50,12 @@ public void TestGetParser([Values(false, true)] bool useV2JsonParser, [Values(1, // GetParser() throws an error when ChunkParserVersion is not 1-3 if (chunkParserVersion == 4 && !useV2JsonParser) { - Exception ex = Assert.Throws(() => parser = ChunkParserFactory.Instance.GetParser(stream)); + Exception ex = Assert.Throws(() => parser = ChunkParserFactory.Instance.GetParser(ResultFormat.JSON, stream)); Assert.AreEqual("Unsupported Chunk Parser version specified in the SFConfiguration", ex.Message); } else { - parser = ChunkParserFactory.Instance.GetParser(stream); + parser = ChunkParserFactory.Instance.GetParser(ResultFormat.JSON, stream); } // GetParser() returns ChunkDeserializer when UseV2JsonParser is true diff --git a/Snowflake.Data.Tests/UnitTests/ChunkStreamingParserTest.cs b/Snowflake.Data.Tests/UnitTests/ChunkStreamingParserTest.cs index b49da7ba5..e1194aaa1 100644 --- a/Snowflake.Data.Tests/UnitTests/ChunkStreamingParserTest.cs +++ b/Snowflake.Data.Tests/UnitTests/ChunkStreamingParserTest.cs @@ -41,7 +41,7 @@ public IChunkParser getParser(string data) { byte[] bytes = Encoding.UTF8.GetBytes(data); Stream stream = new MemoryStream(bytes); - return ChunkParserFactory.Instance.GetParser(stream); + return ChunkParserFactory.Instance.GetParser(ResultFormat.JSON, stream); } [Test] @@ -54,10 +54,11 @@ public async Task 
TestParsingEmptyChunk() SFResultChunk chunk = new SFResultChunk(new string[0, 0]); await parser.ParseChunk(chunk); + chunk.Next(); - Assert.AreEqual(0, chunk.rowSet.GetLength(0)); // Check row length - Assert.AreEqual(0, chunk.rowSet.GetLength(0)); // Check col length - Assert.Throws(() => chunk.ExtractCell(0, 0).SafeToString()); + Assert.AreEqual(0, chunk.RowSet.GetLength(0)); // Check row length + Assert.AreEqual(0, chunk.RowSet.GetLength(0)); // Check col length + Assert.Throws(() => chunk.ExtractCell(0).SafeToString()); } [Test] @@ -70,10 +71,11 @@ public async Task TestParsingEmptyArraysInChunk() SFResultChunk chunk = new SFResultChunk(new string[2, 0]); await parser.ParseChunk(chunk); - - Assert.AreEqual(2, chunk.rowSet.GetLength(0)); // Check row length - Assert.AreEqual(0, chunk.rowSet.GetLength(1)); // Check col length - Assert.Throws(() => chunk.ExtractCell(0, 0).SafeToString()); + chunk.Next(); + + Assert.AreEqual(2, chunk.RowSet.GetLength(0)); // Check row length + Assert.AreEqual(0, chunk.RowSet.GetLength(1)); // Check col length + Assert.Throws(() => chunk.ExtractCell(0).SafeToString()); } [Test] @@ -86,11 +88,12 @@ public async Task TestParsingNonJsonArrayChunk() SFResultChunk chunk = new SFResultChunk(new string[1, 3]); await parser.ParseChunk(chunk); + chunk.Next(); // ChunkStreamingParser is able to parse one row that is not inside of an array - Assert.AreEqual("1", chunk.ExtractCell(0, 0).SafeToString()); - Assert.AreEqual("1.234", chunk.ExtractCell(0, 1).SafeToString()); - Assert.AreEqual("abcde", chunk.ExtractCell(0, 2).SafeToString()); + Assert.AreEqual("1", chunk.ExtractCell(0).SafeToString()); + Assert.AreEqual("1.234", chunk.ExtractCell(1).SafeToString()); + Assert.AreEqual("abcde", chunk.ExtractCell(2).SafeToString()); } [Test] @@ -131,12 +134,15 @@ public async Task TestParsingSimpleChunk() await parser.ParseChunk(chunk); - Assert.AreEqual("1", chunk.ExtractCell(0, 0).SafeToString()); - Assert.AreEqual("1.234", chunk.ExtractCell(0, 1).SafeToString()); - Assert.AreEqual("abcde", chunk.ExtractCell(0, 2).SafeToString()); - Assert.AreEqual("2", chunk.ExtractCell(1, 0).SafeToString()); - Assert.AreEqual("5.678", chunk.ExtractCell(1, 1).SafeToString()); - Assert.AreEqual("fghi", chunk.ExtractCell(1, 2).SafeToString()); + chunk.Next(); + Assert.AreEqual("1", chunk.ExtractCell(0).SafeToString()); + Assert.AreEqual("1.234", chunk.ExtractCell(1).SafeToString()); + Assert.AreEqual("abcde", chunk.ExtractCell(2).SafeToString()); + + chunk.Next(); + Assert.AreEqual("2", chunk.ExtractCell(0).SafeToString()); + Assert.AreEqual("5.678", chunk.ExtractCell(1).SafeToString()); + Assert.AreEqual("fghi", chunk.ExtractCell(2).SafeToString()); } [Test] @@ -150,12 +156,15 @@ public async Task TestParsingChunkWithNullValue() await parser.ParseChunk(chunk); - Assert.AreEqual(null, chunk.ExtractCell(0, 0).SafeToString()); - Assert.AreEqual("1.234", chunk.ExtractCell(0, 1).SafeToString()); - Assert.AreEqual(null, chunk.ExtractCell(0, 2).SafeToString()); - Assert.AreEqual("2", chunk.ExtractCell(1, 0).SafeToString()); - Assert.AreEqual(null, chunk.ExtractCell(1, 1).SafeToString()); - Assert.AreEqual("fghi", chunk.ExtractCell(1, 2).SafeToString()); + chunk.Next(); + Assert.AreEqual(null, chunk.ExtractCell(0).SafeToString()); + Assert.AreEqual("1.234", chunk.ExtractCell(1).SafeToString()); + Assert.AreEqual(null, chunk.ExtractCell(2).SafeToString()); + + chunk.Next(); + Assert.AreEqual("2", chunk.ExtractCell(0).SafeToString()); + Assert.AreEqual(null, chunk.ExtractCell(1).SafeToString()); + 
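// ExtractCell(columnIndex) reads from the row selected by the preceding
// Next() call; the two-argument ExtractCell(rowIndex, columnIndex) overload
// remains on BaseResultChunk for random access.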
Assert.AreEqual("fghi", chunk.ExtractCell(2).SafeToString()); } } } diff --git a/Snowflake.Data.Tests/UnitTests/SFReusableChunkTest.cs b/Snowflake.Data.Tests/UnitTests/SFReusableChunkTest.cs index 91daa8aa4..64ac8d1a2 100755 --- a/Snowflake.Data.Tests/UnitTests/SFReusableChunkTest.cs +++ b/Snowflake.Data.Tests/UnitTests/SFReusableChunkTest.cs @@ -25,144 +25,104 @@ public void ReusableChunkTestDone() } [Test] - public async Task TestSimpleChunk() + public void TestExtractCellWithRowParameterReadsAllRows() { string data = "[ [\"1\", \"1.234\", \"abcde\"], [\"2\", \"5.678\", \"fghi\"] ]"; - byte[] bytes = Encoding.UTF8.GetBytes(data); - Stream stream = new MemoryStream(bytes); - IChunkParser parser = new ReusableChunkParser(stream); + var chunk = PrepareChunkAsync(data, 3, 2).Result; - ExecResponseChunk chunkInfo = new ExecResponseChunk() - { - url = "fake", - uncompressedSize = 100, - rowCount = 2 - }; - - SFReusableChunk chunk = new SFReusableChunk(3); - chunk.Reset(chunkInfo, 0); - - await parser.ParseChunk(chunk); - - Assert.AreEqual("1", chunk.ExtractCell(0, 0).SafeToString()); + Assert.AreEqual("1", chunk.ExtractCell(0,0).SafeToString()); Assert.AreEqual("1.234", chunk.ExtractCell(0, 1).SafeToString()); Assert.AreEqual("abcde", chunk.ExtractCell(0, 2).SafeToString()); + Assert.AreEqual("2", chunk.ExtractCell(1, 0).SafeToString()); Assert.AreEqual("5.678", chunk.ExtractCell(1, 1).SafeToString()); Assert.AreEqual("fghi", chunk.ExtractCell(1, 2).SafeToString()); } [Test] - public async Task TestChunkWithNull() + public void TestSimpleChunk() { - string data = "[ [null, \"1.234\", null], [\"2\", null, \"fghi\"] ]"; - byte[] bytes = Encoding.UTF8.GetBytes(data); - Stream stream = new MemoryStream(bytes); - IChunkParser parser = new ReusableChunkParser(stream); - - ExecResponseChunk chunkInfo = new ExecResponseChunk() - { - url = "fake", - uncompressedSize = 100, - rowCount = 2 - }; - - SFReusableChunk chunk = new SFReusableChunk(3); - chunk.Reset(chunkInfo, 0); - - await parser.ParseChunk(chunk); + string data = "[ [\"1\", \"1.234\", \"abcde\"], [\"2\", \"5.678\", \"fghi\"] ]"; + var chunk = PrepareChunkAsync(data, 3, 2).Result; + + chunk.Next(); + Assert.AreEqual("1", chunk.ExtractCell(0).SafeToString()); + Assert.AreEqual("1.234", chunk.ExtractCell(1).SafeToString()); + Assert.AreEqual("abcde", chunk.ExtractCell(2).SafeToString()); + + chunk.Next(); + Assert.AreEqual("2", chunk.ExtractCell(0).SafeToString()); + Assert.AreEqual("5.678", chunk.ExtractCell(1).SafeToString()); + Assert.AreEqual("fghi", chunk.ExtractCell(2).SafeToString()); + } - Assert.AreEqual(null, chunk.ExtractCell(0, 0).SafeToString()); - Assert.AreEqual("1.234", chunk.ExtractCell(0, 1).SafeToString()); - Assert.AreEqual(null, chunk.ExtractCell(0, 2).SafeToString()); - Assert.AreEqual("2", chunk.ExtractCell(1, 0).SafeToString()); - Assert.AreEqual(null, chunk.ExtractCell(1, 1).SafeToString()); - Assert.AreEqual("fghi", chunk.ExtractCell(1, 2).SafeToString()); + [Test] + public void TestChunkWithNull() + { + string data = "[ [null, \"1.234\", null], [\"2\", null, \"fghi\"] ]"; + var chunk = PrepareChunkAsync(data, 3, 2).Result; + + chunk.Next(); + Assert.AreEqual(null, chunk.ExtractCell(0).SafeToString()); + Assert.AreEqual("1.234", chunk.ExtractCell(1).SafeToString()); + Assert.AreEqual(null, chunk.ExtractCell(2).SafeToString()); + + chunk.Next(); + Assert.AreEqual("2", chunk.ExtractCell(0).SafeToString()); + Assert.AreEqual(null, chunk.ExtractCell(1).SafeToString()); + Assert.AreEqual("fghi", 
chunk.ExtractCell(2).SafeToString()); } [Test] - public async Task TestChunkWithDate() + public void TestChunkWithDate() { string data = "[ [null, \"2019-08-21T11:58:00\", null], [\"2\", null, \"fghi\"] ]"; - byte[] bytes = Encoding.UTF8.GetBytes(data); - Stream stream = new MemoryStream(bytes); - IChunkParser parser = new ReusableChunkParser(stream); - - ExecResponseChunk chunkInfo = new ExecResponseChunk() - { - url = "fake", - uncompressedSize = 100, - rowCount = 2 - }; - - SFReusableChunk chunk = new SFReusableChunk(3); - chunk.Reset(chunkInfo, 0); - - await parser.ParseChunk(chunk); - - Assert.AreEqual(null, chunk.ExtractCell(0, 0).SafeToString()); - Assert.AreEqual("2019-08-21T11:58:00", chunk.ExtractCell(0, 1).SafeToString()); - Assert.AreEqual(null, chunk.ExtractCell(0, 2).SafeToString()); - Assert.AreEqual("2", chunk.ExtractCell(1, 0).SafeToString()); - Assert.AreEqual(null, chunk.ExtractCell(1, 1).SafeToString()); - Assert.AreEqual("fghi", chunk.ExtractCell(1, 2).SafeToString()); + var chunk = PrepareChunkAsync(data, 3, 2).Result; + + chunk.Next(); + Assert.AreEqual(null, chunk.ExtractCell(0).SafeToString()); + Assert.AreEqual("2019-08-21T11:58:00", chunk.ExtractCell(1).SafeToString()); + Assert.AreEqual(null, chunk.ExtractCell(2).SafeToString()); + + chunk.Next(); + Assert.AreEqual("2", chunk.ExtractCell(0).SafeToString()); + Assert.AreEqual(null, chunk.ExtractCell(1).SafeToString()); + Assert.AreEqual("fghi", chunk.ExtractCell(2).SafeToString()); } [Test] - public async Task TestChunkWithEscape() + public void TestChunkWithEscape() { string data = "[ [\"\\\\åäö\\nÅÄÖ\\r\", \"1.234\", null], [\"2\", null, \"fghi\"] ]"; - byte[] bytes = Encoding.UTF8.GetBytes(data); - Stream stream = new MemoryStream(bytes); - IChunkParser parser = new ReusableChunkParser(stream); - - ExecResponseChunk chunkInfo = new ExecResponseChunk() - { - url = "fake", - uncompressedSize = bytes.Length, - rowCount = 2 - }; - - SFReusableChunk chunk = new SFReusableChunk(3); - chunk.Reset(chunkInfo, 0); - - await parser.ParseChunk(chunk); - - Assert.AreEqual("\\åäö\nÅÄÖ\r", chunk.ExtractCell(0, 0).SafeToString()); - Assert.AreEqual("1.234", chunk.ExtractCell(0, 1).SafeToString()); - Assert.AreEqual(null, chunk.ExtractCell(0, 2).SafeToString()); - Assert.AreEqual("2", chunk.ExtractCell(1, 0).SafeToString()); - Assert.AreEqual(null, chunk.ExtractCell(1, 1).SafeToString()); - Assert.AreEqual("fghi", chunk.ExtractCell(1, 2).SafeToString()); + var chunk = PrepareChunkAsync(data, 3, 2).Result; + + chunk.Next(); + Assert.AreEqual("\\åäö\nÅÄÖ\r", chunk.ExtractCell(0).SafeToString()); + Assert.AreEqual("1.234", chunk.ExtractCell(1).SafeToString()); + Assert.AreEqual(null, chunk.ExtractCell(2).SafeToString()); + + chunk.Next(); + Assert.AreEqual("2", chunk.ExtractCell(0).SafeToString()); + Assert.AreEqual(null, chunk.ExtractCell(1).SafeToString()); + Assert.AreEqual("fghi", chunk.ExtractCell(2).SafeToString()); } [Test] - public async Task TestChunkWithLongString() + public void TestChunkWithLongString() { string longstring = new string('å', 10 * 1000 * 1000); string data = "[ [\"åäö\\nÅÄÖ\\r\", \"1.234\", null], [\"2\", null, \"" + longstring + "\"] ]"; - byte[] bytes = Encoding.UTF8.GetBytes(data); - Stream stream = new MemoryStream(bytes); - IChunkParser parser = new ReusableChunkParser(stream); - - ExecResponseChunk chunkInfo = new ExecResponseChunk() - { - url = "fake", - uncompressedSize = bytes.Length, - rowCount = 2 - }; - - SFReusableChunk chunk = new SFReusableChunk(3); - chunk.Reset(chunkInfo, 0); - - await 
parser.ParseChunk(chunk);
-
- Assert.AreEqual("åäö\nÅÄÖ\r", chunk.ExtractCell(0, 0).SafeToString());
- Assert.AreEqual("1.234", chunk.ExtractCell(0, 1).SafeToString());
- Assert.AreEqual(null, chunk.ExtractCell(0, 2).SafeToString());
- Assert.AreEqual("2", chunk.ExtractCell(1, 0).SafeToString());
- Assert.AreEqual(null, chunk.ExtractCell(1, 1).SafeToString());
- Assert.AreEqual(longstring, chunk.ExtractCell(1, 2).SafeToString());
+ var chunk = PrepareChunkAsync(data, 3, 2).Result;
+
+ chunk.Next();
+ Assert.AreEqual("åäö\nÅÄÖ\r", chunk.ExtractCell(0).SafeToString());
+ Assert.AreEqual("1.234", chunk.ExtractCell(1).SafeToString());
+ Assert.AreEqual(null, chunk.ExtractCell(2).SafeToString());
+
+ chunk.Next();
+ Assert.AreEqual("2", chunk.ExtractCell(0).SafeToString());
+ Assert.AreEqual(null, chunk.ExtractCell(1).SafeToString());
+ Assert.AreEqual(longstring, chunk.ExtractCell(2).SafeToString());
 }

 [Test]
@@ -170,23 +130,10 @@ public async Task TestParserError1()
 {
 // Unterminated escape sequence
 string data = "[ [\"åäö\\";
- byte[] bytes = Encoding.UTF8.GetBytes(data);
- Stream stream = new MemoryStream(bytes);
- IChunkParser parser = new ReusableChunkParser(stream);
-
- ExecResponseChunk chunkInfo = new ExecResponseChunk()
- {
- url = "fake",
- uncompressedSize = bytes.Length,
- rowCount = 1
- };
-
- SFReusableChunk chunk = new SFReusableChunk(1);
- chunk.Reset(chunkInfo, 0);

 try
 {
- await parser.ParseChunk(chunk);
+ await PrepareChunkAsync(data, 1, 1);
 Assert.Fail();
 }
 catch (SnowflakeDbException e)
 {
@@ -200,23 +147,10 @@ public async Task TestParserError2()
 {
 // Unterminated string
 string data = "[ [\"åäö";
- byte[] bytes = Encoding.UTF8.GetBytes(data);
- Stream stream = new MemoryStream(bytes);
- IChunkParser parser = new ReusableChunkParser(stream);
-
- ExecResponseChunk chunkInfo = new ExecResponseChunk()
- {
- url = "fake",
- uncompressedSize = bytes.Length,
- rowCount = 1
- };
-
- SFReusableChunk chunk = new SFReusableChunk(1);
- chunk.Reset(chunkInfo, 0);
-
+
 try
 {
- await parser.ParseChunk(chunk);
+ await PrepareChunkAsync(data, 1, 1);
 Assert.Fail();
 }
 catch (SnowflakeDbException e)
 {
@@ -226,11 +160,75 @@ public async Task TestParserError2()
 }

 [Test]
- public async Task TestParserWithTab()
+ public void TestParserWithTab()
 {
 // String containing a tab character
 string data = "[[\"abc\t\"]]";
- byte[] bytes = Encoding.UTF8.GetBytes(data);
+ var chunk = PrepareChunkAsync(data, 1, 1).Result;
+
+ chunk.Next();
+ string val = chunk.ExtractCell(0).SafeToString();
+ Assert.AreEqual("abc\t", val);
+ }
+
+ [Test]
+ public void TestNextIteratesThroughAllRecords()
+ {
+ const int RowCount = 3;
+ string data = "[ [\"1\"], [\"2\"], [\"3\"] ]";
+ var chunk = PrepareChunkAsync(data, 1, RowCount).Result;
+
+ for (var i = 0; i < RowCount; ++i)
+ {
+ Assert.IsTrue(chunk.Next());
+ }
+ Assert.IsFalse(chunk.Next());
+ }
+
+ [Test]
+ public void TestRewindIteratesThroughAllRecords()
+ {
+ const int RowCount = 3;
+ string data = "[ [\"1\"], [\"2\"], [\"3\"] ]";
+ var chunk = PrepareChunkAsync(data, 1, RowCount).Result;
+
+ for (var i = 0; i < RowCount; ++i)
+ {
+ chunk.Next();
+ }
+ chunk.Next();
+
+ for (var i = 0; i < RowCount; ++i)
+ {
+ Assert.IsTrue(chunk.Rewind());
+ }
+ Assert.IsFalse(chunk.Rewind());
+ }
+
+ [Test]
+ public void TestResetClearsChunkData()
+ {
+ const int RowCount = 3;
+ string data = "[ [\"1\"], [\"2\"], [\"3\"] ]";
+ var chunk = PrepareChunkAsync(data, 1, RowCount).Result;
+
+ ExecResponseChunk chunkInfo = new ExecResponseChunk()
+ {
+ url = "new_url",
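// values deliberately different from the parsed chunk, so the assertions
// below can verify that Reset() really replaced them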
uncompressedSize = 100,
+ rowCount = 200
+ };
+
+ chunk.Reset(chunkInfo, 0);
+
+ Assert.AreEqual(0, chunk.ChunkIndex);
+ Assert.AreEqual(chunkInfo.url, chunk.Url);
+ Assert.AreEqual(chunkInfo.rowCount, chunk.RowCount);
+ }
+
+ private async Task<SFReusableChunk> PrepareChunkAsync(string stringData, int colCount, int rowCount)
+ {
+ byte[] bytes = Encoding.UTF8.GetBytes(stringData);
 Stream stream = new MemoryStream(bytes);
 IChunkParser parser = new ReusableChunkParser(stream);
@@ -238,16 +236,14 @@
 {
 url = "fake",
 uncompressedSize = bytes.Length,
- rowCount = 1
+ rowCount = rowCount
 };

- SFReusableChunk chunk = new SFReusableChunk(1);
+ SFReusableChunk chunk = new SFReusableChunk(colCount);
 chunk.Reset(chunkInfo, 0);
-
+
 await parser.ParseChunk(chunk);

- string val = chunk.ExtractCell(0, 0).SafeToString();
- Assert.AreEqual("abc\t", chunk.ExtractCell(0, 0).SafeToString());
+ return chunk;
 }
-
 }
 }
diff --git a/Snowflake.Data/Core/ArrowChunkParser.cs b/Snowflake.Data/Core/ArrowChunkParser.cs
new file mode 100755
index 000000000..892861f0e
--- /dev/null
+++ b/Snowflake.Data/Core/ArrowChunkParser.cs
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+ */
+
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+using System.Threading.Tasks;
+using Apache.Arrow;
+using Apache.Arrow.Ipc;
+
+namespace Snowflake.Data.Core
+{
+ public class ArrowChunkParser : IChunkParser
+ {
+ private readonly Stream stream;
+
+ internal ArrowChunkParser(Stream stream)
+ {
+ this.stream = stream;
+ }
+
+ public async Task ParseChunk(IResultChunk chunk)
+ {
+ ArrowResultChunk resultChunk = (ArrowResultChunk)chunk;
+
+ using (var reader = new ArrowStreamReader(stream))
+ {
+ RecordBatch recordBatch;
+ while ((recordBatch = await reader.ReadNextRecordBatchAsync().ConfigureAwait(false)) != null)
+ {
+ resultChunk.AddRecordBatch(recordBatch);
+ }
+ }
+ }
+ }
+}
diff --git a/Snowflake.Data/Core/ArrowResultChunk.cs b/Snowflake.Data/Core/ArrowResultChunk.cs
index 75b29a4a2..b226ec129 100755
--- a/Snowflake.Data/Core/ArrowResultChunk.cs
+++ b/Snowflake.Data/Core/ArrowResultChunk.cs
@@ -3,38 +3,105 @@
 */

 using System;
+using System.Collections.Generic;
 using System.Text;
 using Apache.Arrow;
 using Apache.Arrow.Types;

 namespace Snowflake.Data.Core
 {
- internal class ArrowResultChunk : IResultChunk
+ internal class ArrowResultChunk : BaseResultChunk
 {
- public RecordBatch RecordBatch { get; set; }
+ internal override ResultFormat Format => ResultFormat.ARROW;

- private int _rowCount;
- private int _colCount;
- private int _chunkIndex;
+ public List<RecordBatch> RecordBatch { get; set; }
+ private int _currentBatchIndex = 0;
+ private int _currentRecordIndex = -1;
+
 public ArrowResultChunk(RecordBatch recordBatch)
 {
- RecordBatch = recordBatch;
+ RecordBatch = new List<RecordBatch>{recordBatch};

- _rowCount = recordBatch.Length;
- _colCount = recordBatch.ColumnCount;
- _chunkIndex = 0;
+ RowCount = recordBatch.Length;
+ ColumnCount = recordBatch.ColumnCount;
+ ChunkIndex = 0;
+ }
+
+ public ArrowResultChunk(int columnCount)
+ {
+ RecordBatch = new List<RecordBatch>();
+
+ RowCount = 0;
+ ColumnCount = columnCount;
+ ChunkIndex = 0;
+ }
+
+ public void AddRecordBatch(RecordBatch recordBatch)
+ {
+ RecordBatch.Add(recordBatch);
 }

- public UTF8Buffer ExtractCell(int rowIndex, int columnIndex)
+ internal override void Reset(ExecResponseChunk chunkInfo, int chunkIndex)
 {
- var column = RecordBatch.Column(columnIndex);
+ base.Reset(chunkInfo, chunkIndex);
+
+ _currentBatchIndex = 0;
+
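// -1 so that the first call to Next() pre-increments the cursor onto
// row 0 of the first record batch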
_currentRecordIndex = -1; + RecordBatch.Clear(); + } + + internal override bool Next() + { + _currentRecordIndex += 1; + if (_currentRecordIndex < RecordBatch[_currentBatchIndex].Length) + return true; + + _currentBatchIndex += 1; + _currentRecordIndex = 0; + + return _currentBatchIndex < RecordBatch.Count; + } + + internal override bool Rewind() + { + _currentRecordIndex -= 1; + if (_currentRecordIndex >= 0) + return true; + + _currentBatchIndex -= 1; + + if (_currentBatchIndex >= 0) + { + _currentRecordIndex = RecordBatch[_currentBatchIndex].Length - 1; + return true; + } + + return false; + } + + public override UTF8Buffer ExtractCell(int rowIndex, int columnIndex) + { + _currentBatchIndex = 0; + _currentRecordIndex = rowIndex; + while (_currentRecordIndex >= RecordBatch[_currentBatchIndex].Length) + { + _currentRecordIndex -= RecordBatch[_currentBatchIndex].Length; + _currentBatchIndex += 1; + } + + return ExtractCell(columnIndex); + } + + public override UTF8Buffer ExtractCell(int columnIndex) + { + var column = RecordBatch[_currentBatchIndex].Column(columnIndex); string stringBuffer; switch (column.Data.DataType.TypeId) { case ArrowTypeId.Int32: - stringBuffer = ((Int32Array)column).GetValue(rowIndex).ToString(); + stringBuffer = ((Int32Array)column).GetValue(_currentRecordIndex).ToString(); break; // TODO in SNOW-893834 - other types @@ -48,16 +115,6 @@ public UTF8Buffer ExtractCell(int rowIndex, int columnIndex) return new UTF8Buffer(Encoding.UTF8.GetBytes(stringBuffer)); } - - public int GetRowCount() - { - return _rowCount; - } - - public int GetChunkIndex() - { - return _chunkIndex; - } } } diff --git a/Snowflake.Data/Core/ArrowResultSet.cs b/Snowflake.Data/Core/ArrowResultSet.cs index fe561849d..5c3103320 100755 --- a/Snowflake.Data/Core/ArrowResultSet.cs +++ b/Snowflake.Data/Core/ArrowResultSet.cs @@ -16,10 +16,8 @@ class ArrowResultSet : SFBaseResultSet { private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger(); - private int _currentChunkRowIdx = -1; - private int _currentChunkRowCount; private readonly int _totalChunkCount; - private IResultChunk _currentChunk; + private BaseResultChunk _currentChunk; private readonly IChunkDownloader _chunkDownloader; public ArrowResultSet(QueryExecResponseData responseData, SFStatement sfStatement, CancellationToken cancellationToken) : base() @@ -32,7 +30,6 @@ public ArrowResultSet(QueryExecResponseData responseData, SFStatement sfStatemen using (var reader = new ArrowStreamReader(stream)) { var recordBatch = reader.ReadNextRecordBatch(); - _currentChunkRowCount = recordBatch.Length; _currentChunk = new ArrowResultChunk(recordBatch); } } @@ -43,9 +40,7 @@ public ArrowResultSet(QueryExecResponseData responseData, SFStatement sfStatemen if (responseData.chunks != null) { _totalChunkCount = responseData.chunks.Count; - - // TODO in SNOW-893835 - support for multiple chunks - throw new SnowflakeDbException(SFError.UNSUPPORTED_FEATURE); + _chunkDownloader = ChunkDownloaderFactory.GetDownloader(responseData, this, cancellationToken); } responseData.rowSet = null; @@ -67,16 +62,14 @@ internal override async Task NextAsync() { ThrowIfClosed(); - _currentChunkRowIdx++; - if (_currentChunkRowIdx < _currentChunkRowCount) - { + if (_currentChunk.Next()) return true; - } if (_totalChunkCount > 0) { - // TODO in SNOW-893835 - support for multiple chunks - throw new SnowflakeDbException(SFError.UNSUPPORTED_FEATURE); + s_logger.Debug("Get next chunk from chunk downloader"); + _currentChunk = await 
_chunkDownloader.GetNextChunkAsync().ConfigureAwait(false); + return _currentChunk?.Next() ?? false; } return false; @@ -86,16 +79,14 @@ internal override bool Next() { ThrowIfClosed(); - _currentChunkRowIdx++; - if (_currentChunkRowIdx < _currentChunkRowCount) - { + if (_currentChunk.Next()) return true; - } - + if (_totalChunkCount > 0) { - // TODO in SNOW-893835 - support for multiple chunks - throw new SnowflakeDbException(SFError.UNSUPPORTED_FEATURE); + s_logger.Debug("Get next chunk from chunk downloader"); + _currentChunk = Task.Run(async() => await (_chunkDownloader.GetNextChunkAsync()).ConfigureAwait(false)).Result; + return _currentChunk?.Next() ?? false; } return false; @@ -118,7 +109,7 @@ internal override bool HasRows() return false; } - return _currentChunkRowCount > 0 || _totalChunkCount > 0; + return _currentChunk.RowCount > 0 || _totalChunkCount > 0; } /// @@ -129,11 +120,12 @@ internal override bool Rewind() { ThrowIfClosed(); - if (_currentChunkRowIdx >= 0) - { - // TODO in SNOW-893835 - rewind - _currentChunkRowIdx--; + if (_currentChunk.Rewind()) return true; + + if (_currentChunk.ChunkIndex > 0) + { + s_logger.Warn("Unable to rewind to the previous chunk"); } return false; @@ -148,7 +140,7 @@ internal override UTF8Buffer getObjectInternal(int columnIndex) throw new SnowflakeDbException(SFError.COLUMN_INDEX_OUT_OF_BOUND, columnIndex); } - return _currentChunk.ExtractCell(_currentChunkRowIdx, columnIndex); + return _currentChunk.ExtractCell(columnIndex); } private void UpdateSessionStatus(QueryExecResponseData responseData) @@ -157,13 +149,5 @@ private void UpdateSessionStatus(QueryExecResponseData responseData) session.UpdateDatabaseAndSchema(responseData.finalDatabaseName, responseData.finalSchemaName); session.UpdateSessionParameterMap(responseData.parameters); } - - private void ThrowIfClosed() - { - if (isClosed) - { - throw new SnowflakeDbException(SFError.DATA_READER_ALREADY_CLOSED); - } - } } } diff --git a/Snowflake.Data/Core/BaseResultChunk.cs b/Snowflake.Data/Core/BaseResultChunk.cs new file mode 100755 index 000000000..da0591364 --- /dev/null +++ b/Snowflake.Data/Core/BaseResultChunk.cs @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
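 *
 * BaseResultChunk factors out the row cursor shared by the JSON and Arrow
 * chunk implementations. A minimal sketch of the consumer loop built on it
 * (chunk stands for any BaseResultChunk):
 *
 *   while (chunk.Next())
 *   {
 *       UTF8Buffer cell = chunk.ExtractCell(0); // column 0 of the current row
 *   }
 *
 * Next() and Rewind() return false once the cursor leaves the chunk, which is
 * the result set's signal to ask the chunk downloader for the next chunk.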
+ */ + +namespace Snowflake.Data.Core +{ + public abstract class BaseResultChunk : IResultChunk + { + internal abstract ResultFormat Format { get; } + + public int RowCount { get; protected set; } + + public int ColumnCount { get; protected set; } + + public int ChunkIndex { get; protected set; } + + internal string Url { get; set; } + + internal string[,] RowSet { get; set; } + + public int GetRowCount() => RowCount; + + public int GetChunkIndex() => ChunkIndex; + + public abstract UTF8Buffer ExtractCell(int rowIndex, int columnIndex); + + public abstract UTF8Buffer ExtractCell(int columnIndex); + + internal abstract bool Next(); + + internal abstract bool Rewind(); + + internal virtual void Reset(ExecResponseChunk chunkInfo, int chunkIndex) + { + RowCount = chunkInfo.rowCount; + Url = chunkInfo.url; + ChunkIndex = chunkIndex; + } + + internal virtual void ResetForRetry() + { + } + } +} diff --git a/Snowflake.Data/Core/ChunkDeserializer.cs b/Snowflake.Data/Core/ChunkDeserializer.cs index 32357536d..385b1ddbe 100755 --- a/Snowflake.Data/Core/ChunkDeserializer.cs +++ b/Snowflake.Data/Core/ChunkDeserializer.cs @@ -27,7 +27,7 @@ await Task.Run(() => using (StreamReader sr = new StreamReader(stream)) using (JsonTextReader jr = new JsonTextReader(sr)) { - ((SFResultChunk)chunk).rowSet = JsonSerializer.Deserialize(jr); + ((SFResultChunk)chunk).RowSet = JsonSerializer.Deserialize(jr); } }); } diff --git a/Snowflake.Data/Core/ChunkDownloaderFactory.cs b/Snowflake.Data/Core/ChunkDownloaderFactory.cs index 92f7bf2a5..b684b06a3 100755 --- a/Snowflake.Data/Core/ChunkDownloaderFactory.cs +++ b/Snowflake.Data/Core/ChunkDownloaderFactory.cs @@ -36,7 +36,8 @@ public static IChunkDownloader GetDownloader(QueryExecResponseData responseData, responseData.qrmk, responseData.chunkHeaders, cancellationToken, - resultSet); + resultSet, + responseData.queryResultFormat); default: throw new Exception("Unsupported Chunk Downloader version specified in the SFConfiguration"); } diff --git a/Snowflake.Data/Core/ChunkParserFactory.cs b/Snowflake.Data/Core/ChunkParserFactory.cs index c571d77c5..2bd9b4526 100755 --- a/Snowflake.Data/Core/ChunkParserFactory.cs +++ b/Snowflake.Data/Core/ChunkParserFactory.cs @@ -12,8 +12,11 @@ class ChunkParserFactory : IChunkParserFactory { public static IChunkParserFactory Instance = new ChunkParserFactory(); - public IChunkParser GetParser(Stream stream) + public IChunkParser GetParser(ResultFormat resultFormat, Stream stream) { + if (resultFormat == ResultFormat.ARROW) + return new ArrowChunkParser(stream); + switch (SFConfiguration.Instance().GetChunkParserVersion()) { case 1: diff --git a/Snowflake.Data/Core/ChunkStreamingParser.cs b/Snowflake.Data/Core/ChunkStreamingParser.cs index 4ecfd89ce..d37d6dec3 100755 --- a/Snowflake.Data/Core/ChunkStreamingParser.cs +++ b/Snowflake.Data/Core/ChunkStreamingParser.cs @@ -23,46 +23,45 @@ public async Task ParseChunk(IResultChunk chunk) { await Task.Run(() => { + // parse results row by row + using (StreamReader sr = new StreamReader(stream)) + using (JsonTextReader jr = new JsonTextReader(sr) { DateParseHandling = DateParseHandling.None }) + { + int row = 0; + int col = 0; - // parse results row by row - using (StreamReader sr = new StreamReader(stream)) - using (JsonTextReader jr = new JsonTextReader(sr) { DateParseHandling = DateParseHandling.None }) - { - int row = 0; - int col = 0; - - var outputMatrix = new string[chunk.GetRowCount(), ((SFResultChunk)chunk).colCount]; + var outputMatrix = new string[chunk.GetRowCount(), 
((SFResultChunk)chunk).ColumnCount];

 while (jr.Read())
 {
 switch (jr.TokenType)
 {
 case JsonToken.StartArray:
 case JsonToken.None:
 break;

 case JsonToken.EndArray:
 if (col > 0)
 {
 col = 0;
 row++;
 }

 break;

 case JsonToken.Null:
 outputMatrix[row, col++] = null;
 break;

 case JsonToken.String:
 outputMatrix[row, col++] = (string)jr.Value;
 break;

 default:
 throw new SnowflakeDbException(SFError.INTERNAL_ERROR, $"Unexpected token type: {jr.TokenType}");
 }
 }

- ((SFResultChunk)chunk).rowSet = outputMatrix;
+ ((SFResultChunk)chunk).RowSet = outputMatrix;
 }
 });
 }
diff --git a/Snowflake.Data/Core/IChunkDownloader.cs b/Snowflake.Data/Core/IChunkDownloader.cs
index a5203e86a..0373cbf7c 100755
--- a/Snowflake.Data/Core/IChunkDownloader.cs
+++ b/Snowflake.Data/Core/IChunkDownloader.cs
@@ -8,6 +8,6 @@ namespace Snowflake.Data.Core
 {
 interface IChunkDownloader
 {
- Task<IResultChunk> GetNextChunkAsync();
+ Task<BaseResultChunk> GetNextChunkAsync();
 }
 }
diff --git a/Snowflake.Data/Core/IChunkParserFactory.cs b/Snowflake.Data/Core/IChunkParserFactory.cs
index 09355de28..1d93d3d78 100644
--- a/Snowflake.Data/Core/IChunkParserFactory.cs
+++ b/Snowflake.Data/Core/IChunkParserFactory.cs
@@ -4,6 +4,6 @@ namespace Snowflake.Data.Core
 {
 internal interface IChunkParserFactory
 {
- IChunkParser GetParser(Stream stream);
+ IChunkParser GetParser(ResultFormat resultFormat, Stream stream);
 }
 }
diff --git a/Snowflake.Data/Core/IResultChunk.cs b/Snowflake.Data/Core/IResultChunk.cs
index 86ca84258..662bc6999 100755
--- a/Snowflake.Data/Core/IResultChunk.cs
+++ b/Snowflake.Data/Core/IResultChunk.cs
@@ -1,21 +1,21 @@
-/*
- * Copyright (c) 2012-2019 Snowflake Computing Inc. All rights reserved.
- */
-
-namespace Snowflake.Data.Core
-{
- public enum ResultFormat // TODO add tests for ResultFormat
- {
- JSON,
- ARROW
- }
-
- public interface IResultChunk
- {
- UTF8Buffer ExtractCell(int rowIndex, int columnIndex);
-
- int GetRowCount();
-
- int GetChunkIndex();
- }
-}
+/*
+ * Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
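 *
 * ResultFormat is the value ChunkParserFactory.GetParser dispatches on: ARROW
 * streams get an ArrowChunkParser, JSON streams keep the version-selected JSON
 * parsers. In sketch form (chunk is any BaseResultChunk, stream its payload):
 *
 *   IChunkParser parser = ChunkParserFactory.Instance.GetParser(chunk.Format, stream);
 *   await parser.ParseChunk(chunk);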
+ */
+
+namespace Snowflake.Data.Core
+{
+ public enum ResultFormat
+ {
+ JSON,
+ ARROW
+ }
+
+ public interface IResultChunk
+ {
+ UTF8Buffer ExtractCell(int rowIndex, int columnIndex);
+
+ int GetRowCount();
+
+ int GetChunkIndex();
+ }
+}
\ No newline at end of file
diff --git a/Snowflake.Data/Core/ReusableChunkParser.cs b/Snowflake.Data/Core/ReusableChunkParser.cs
index c971d76fa..ea3d8d3bd 100755
--- a/Snowflake.Data/Core/ReusableChunkParser.cs
+++ b/Snowflake.Data/Core/ReusableChunkParser.cs
@@ -139,9 +139,9 @@ await Task.Run(() =>
 {
 }
 // The 'u' case already writes to stream so skip to prevent re-writing
 // If not skipped, unicode characters are added an extra u (e.g "/u007f" becomes "/u007fu")
- if (!caseU)
- {
- ms.WriteByte((byte)c);
+ if (!caseU)
+ {
+ ms.WriteByte((byte)c);
 }
 }
 else
diff --git a/Snowflake.Data/Core/SFBaseResultSet.cs b/Snowflake.Data/Core/SFBaseResultSet.cs
index bebb81974..5b7422c26 100755
--- a/Snowflake.Data/Core/SFBaseResultSet.cs
+++ b/Snowflake.Data/Core/SFBaseResultSet.cs
@@ -6,6 +6,7 @@
 using System.Text;
 using System.Threading;
 using System.Threading.Tasks;
+using Snowflake.Data.Client;

 namespace Snowflake.Data.Core
 {
@@ -85,5 +86,13 @@ internal void close()
 isClosed = true;
 }

+ internal void ThrowIfClosed()
+ {
+ if (isClosed)
+ {
+ throw new SnowflakeDbException(SFError.DATA_READER_ALREADY_CLOSED);
+ }
+ }
+
 }
}
diff --git a/Snowflake.Data/Core/SFBlockingChunkDownloader.cs b/Snowflake.Data/Core/SFBlockingChunkDownloader.cs
index ccc13174c..3049e5ea5 100755
--- a/Snowflake.Data/Core/SFBlockingChunkDownloader.cs
+++ b/Snowflake.Data/Core/SFBlockingChunkDownloader.cs
@@ -72,11 +72,11 @@ private int GetPrefetchThreads(SFBaseResultSet resultSet)
 return Int32.Parse(val);
 }

- private BlockingCollection<Task<IResultChunk>> _downloadTasks;
+ private BlockingCollection<Task<BaseResultChunk>> _downloadTasks;

 private void FillDownloads()
 {
- _downloadTasks = new BlockingCollection<Task<IResultChunk>>(prefetchThreads);
+ _downloadTasks = new BlockingCollection<Task<BaseResultChunk>>(prefetchThreads);

 Task.Run(() =>
 {
@@ -96,11 +96,11 @@ private void FillDownloads()
 });
 }

- public Task<IResultChunk> GetNextChunkAsync()
+ public Task<BaseResultChunk> GetNextChunkAsync()
 {
 if (_downloadTasks.IsCompleted)
 {
- return Task.FromResult<IResultChunk>(null);
+ return Task.FromResult<BaseResultChunk>(null);
 }
 else
 {
@@ -108,17 +108,15 @@ public Task<IResultChunk> GetNextChunkAsync()
 }
 }

- private async Task<IResultChunk> DownloadChunkAsync(DownloadContext downloadContext)
+ private async Task<BaseResultChunk> DownloadChunkAsync(DownloadContext downloadContext)
 {
- logger.Info($"Start donwloading chunk #{downloadContext.chunkIndex}");
- SFResultChunk chunk = downloadContext.chunk;
-
- chunk.downloadState = DownloadState.IN_PROGRESS;
+ logger.Info($"Start downloading chunk #{downloadContext.chunkIndex}");
+ BaseResultChunk chunk = downloadContext.chunk;

 S3DownloadRequest downloadRequest = new S3DownloadRequest()
 {
- Url = new UriBuilder(chunk.url).Uri,
+ Url = new UriBuilder(chunk.Url).Uri,
 qrmk = downloadContext.qrmk,
 // s3 download request timeout to one hour
 RestTimeout = TimeSpan.FromHours(1),
@@ -141,7 +139,6 @@ private async Task<IResultChunk> DownloadChunkAsync(DownloadContext downloadCont
 ParseStreamIntoChunk(stream, chunk);

- chunk.downloadState = DownloadState.SUCCESS;
 logger.Info($"Succeed downloading chunk #{downloadContext.chunkIndex}");

 return chunk;
@@ -157,21 +154,21 @@ private async Task<IResultChunk> DownloadChunkAsync(DownloadContext downloadCont
 ///
 ///
 ///
- private void ParseStreamIntoChunk(Stream content, SFResultChunk resultChunk)
+ private void ParseStreamIntoChunk(Stream content, BaseResultChunk resultChunk)
 {
 Stream openBracket = new
MemoryStream(Encoding.UTF8.GetBytes("["));
 Stream closeBracket = new MemoryStream(Encoding.UTF8.GetBytes("]"));

 Stream concatStream = new ConcatenatedStream(new Stream[3] { openBracket, content, closeBracket});

- IChunkParser parser = ChunkParserFactory.Instance.GetParser(concatStream);
+ IChunkParser parser = ChunkParserFactory.Instance.GetParser(resultChunk.Format, concatStream);
 parser.ParseChunk(resultChunk);
 }
 }

 class DownloadContext
 {
- public SFResultChunk chunk { get; set; }
+ public BaseResultChunk chunk { get; set; }

 public int chunkIndex { get; set; }
diff --git a/Snowflake.Data/Core/SFBlockingChunkDownloaderV3.cs b/Snowflake.Data/Core/SFBlockingChunkDownloaderV3.cs
index 1930137d4..32c19caab 100755
--- a/Snowflake.Data/Core/SFBlockingChunkDownloaderV3.cs
+++ b/Snowflake.Data/Core/SFBlockingChunkDownloaderV3.cs
@@ -24,7 +24,7 @@ class SFBlockingChunkDownloaderV3 : IChunkDownloader
 {
 static private SFLogger logger = SFLoggerFactory.GetLogger<SFBlockingChunkDownloaderV3>();

- private List<SFReusableChunk> chunkDatas = new List<SFReusableChunk>();
+ private List<BaseResultChunk> chunkDatas = new List<BaseResultChunk>();

 private string qrmk;
@@ -47,13 +47,14 @@ class SFBlockingChunkDownloaderV3 : IChunkDownloader

 private readonly List<ExecResponseChunk> chunkInfos;

- private readonly List<Task<IResultChunk>> taskQueues;
+ private readonly List<Task<BaseResultChunk>> taskQueues;

 public SFBlockingChunkDownloaderV3(int colCount,
 List<ExecResponseChunk> chunkInfos, string qrmk,
 Dictionary<string, string> chunkHeaders,
 CancellationToken cancellationToken,
- SFBaseResultSet ResultSet)
+ SFBaseResultSet ResultSet,
+ ResultFormat resultFormat)
 {
 this.qrmk = qrmk;
 this.chunkHeaders = chunkHeaders;
@@ -64,18 +65,22 @@ public SFBlockingChunkDownloaderV3(int colCount,
 this.prefetchSlot = Math.Min(chunkInfos.Count, GetPrefetchThreads(ResultSet));
 this.chunkInfos = chunkInfos;
 this.nextChunkToConsumeIndex = 0;
- this.taskQueues = new List<Task<IResultChunk>>();
+ this.taskQueues = new List<Task<BaseResultChunk>>();
 externalCancellationToken = cancellationToken;
 for (int i=0; i GetNextChunkAsync()
- {
- return _downloadTasks.IsCompleted ?
Task.FromResult(null) : _downloadTasks.Take();
- }*/
-
- public async Task<IResultChunk> GetNextChunkAsync()
+ public async Task<BaseResultChunk> GetNextChunkAsync()
 {
 logger.Info($"NextChunkToConsume: {nextChunkToConsumeIndex}, NextChunkToDownload: {nextChunkToDownloadIndex}");
 if (nextChunkToConsumeIndex < chunkInfos.Count)
 {
- Task<IResultChunk> chunk = taskQueues[nextChunkToConsumeIndex % prefetchSlot];
+ Task<BaseResultChunk> chunk = taskQueues[nextChunkToConsumeIndex % prefetchSlot];
 if (nextChunkToDownloadIndex < chunkInfos.Count && nextChunkToConsumeIndex > 0)
 {
- SFReusableChunk reusableChunk = chunkDatas[nextChunkToDownloadIndex % prefetchSlot];
+ BaseResultChunk reusableChunk = chunkDatas[nextChunkToDownloadIndex % prefetchSlot];
 reusableChunk.Reset(chunkInfos[nextChunkToDownloadIndex], nextChunkToDownloadIndex);

 taskQueues[nextChunkToDownloadIndex % prefetchSlot] = DownloadChunkAsync(new DownloadContextV3()
@@ -118,21 +117,25 @@ public async Task<IResultChunk> GetNextChunkAsync()
 cancellationToken = externalCancellationToken
 });
 nextChunkToDownloadIndex++;
- }
+ // in case of one slot we need to return the chunk already downloaded
+ if (prefetchSlot == 1)
+ {
+ chunk = taskQueues[0];
+ }
+ }

 nextChunkToConsumeIndex++;
 return await chunk;
 }
 else
 {
- return await Task.FromResult<IResultChunk>(null);
+ return await Task.FromResult<BaseResultChunk>(null);
 }
 }

- private async Task<IResultChunk> DownloadChunkAsync(DownloadContextV3 downloadContext)
+ private async Task<BaseResultChunk> DownloadChunkAsync(DownloadContextV3 downloadContext)
 {
- //logger.Info($"Start downloading chunk #{downloadContext.chunkIndex}");
- SFReusableChunk chunk = downloadContext.chunk;
+ BaseResultChunk chunk = downloadContext.chunk;
 int backOffInSec = 1;
 bool retry = false;
 int retryCount = 0;
@@ -208,29 +211,20 @@ private async Task<IResultChunk> DownloadChunkAsync(DownloadContextV3 downloadCo
 }
 }
 } while (retry);
- logger.Info($"Succeed downloading chunk #{chunk.chunkIndexToDownload}");
+ logger.Info($"Succeed downloading chunk #{chunk.ChunkIndex}");
 return chunk;
 }
-
- ///
- /// Content from s3 in format of
- /// ["val1", "val2", null, ...],
- /// ["val3", "val4", null, ...],
- /// ...
- /// To parse it as a json, we need to preappend '[' and append ']' to the stream - /// - /// - /// - private async Task ParseStreamIntoChunk(Stream content, IResultChunk resultChunk) + + private async Task ParseStreamIntoChunk(Stream content, BaseResultChunk resultChunk) { - IChunkParser parser = ChunkParserFactory.Instance.GetParser(content); + IChunkParser parser = ChunkParserFactory.Instance.GetParser(resultChunk.Format, content); await parser.ParseChunk(resultChunk); } } class DownloadContextV3 { - public SFReusableChunk chunk { get; set; } + public BaseResultChunk chunk { get; set; } public string qrmk { get; set; } diff --git a/Snowflake.Data/Core/SFChunkDownloaderV2.cs b/Snowflake.Data/Core/SFChunkDownloaderV2.cs index be9f2b77f..b1400e271 100755 --- a/Snowflake.Data/Core/SFChunkDownloaderV2.cs +++ b/Snowflake.Data/Core/SFChunkDownloaderV2.cs @@ -55,8 +55,8 @@ public SFChunkDownloaderV2(int colCount, ListchunkInfos, stri FillDownloads(); } - private BlockingCollection>> _downloadTasks; - private ConcurrentQueue>> _downloadQueue; + private BlockingCollection>> _downloadTasks; + private ConcurrentQueue>> _downloadQueue; private void RunDownloads() { @@ -79,11 +79,11 @@ private void RunDownloads() private void FillDownloads() { - _downloadTasks = new BlockingCollection>>(); + _downloadTasks = new BlockingCollection>>(); foreach (var c in chunks) { - var t = new Lazy>(() => DownloadChunkAsync(new DownloadContextV2() + var t = new Lazy>(() => DownloadChunkAsync(new DownloadContextV2() { chunk = c, chunkIndex = c.ChunkIndex, @@ -97,18 +97,18 @@ private void FillDownloads() _downloadTasks.CompleteAdding(); - _downloadQueue = new ConcurrentQueue>>(_downloadTasks); + _downloadQueue = new ConcurrentQueue>>(_downloadTasks); for (var i = 0; i < prefetchSlot && i < chunks.Count; i++) Task.Run(new Action(RunDownloads)); } - public Task GetNextChunkAsync() + public Task GetNextChunkAsync() { if (_downloadTasks.IsAddingCompleted) { - return Task.FromResult(null); + return Task.FromResult(null); } else { @@ -116,16 +116,14 @@ public Task GetNextChunkAsync() } } - private async Task DownloadChunkAsync(DownloadContextV2 downloadContext) + private async Task DownloadChunkAsync(DownloadContextV2 downloadContext) { logger.Info($"Start downloading chunk #{downloadContext.chunkIndex+1}"); - SFResultChunk chunk = downloadContext.chunk; - - chunk.downloadState = DownloadState.IN_PROGRESS; + BaseResultChunk chunk = downloadContext.chunk; S3DownloadRequest downloadRequest = new S3DownloadRequest() { - Url = new UriBuilder(chunk.url).Uri, + Url = new UriBuilder(chunk.Url).Uri, qrmk = downloadContext.qrmk, // s3 download request timeout to one hour RestTimeout = TimeSpan.FromHours(1), @@ -146,10 +144,9 @@ private async Task DownloadChunkAsync(DownloadContextV2 downloadCo } } - parseStreamIntoChunk(stream, chunk); + ParseStreamIntoChunk(stream, chunk); } - chunk.downloadState = DownloadState.SUCCESS; logger.Info($"Succeed downloading chunk #{downloadContext.chunkIndex+1}"); return chunk; @@ -165,21 +162,21 @@ private async Task DownloadChunkAsync(DownloadContextV2 downloadCo /// /// /// - private static void parseStreamIntoChunk(Stream content, SFResultChunk resultChunk) + private static void ParseStreamIntoChunk(Stream content, BaseResultChunk resultChunk) { Stream openBracket = new MemoryStream(Encoding.UTF8.GetBytes("[")); Stream closeBracket = new MemoryStream(Encoding.UTF8.GetBytes("]")); Stream concatStream = new ConcatenatedStream(new Stream[3] { openBracket, content, closeBracket}); - IChunkParser 
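Reviewer note: the `prefetchSlot == 1` branch added above is the subtle part of this change. The consumer grabs `taskQueues[i % prefetchSlot]` *before* scheduling the next download into that same slot, so with a single slot it would otherwise await a task whose reusable chunk object has already been `Reset(...)` for the next download. A self-contained toy model of that round-robin scheduling (plain `int` chunks and hypothetical names, not the driver's types) that compiles on its own:

```csharp
using System;
using System.Collections.Generic;
using System.Threading.Tasks;

// Toy model of SFBlockingChunkDownloaderV3's slot reuse. With one slot,
// the freshly scheduled task must be re-read from the queue; without the
// fix the consumer would await the task it just replaced.
class RoundRobinPrefetchDemo
{
    static async Task Main()
    {
        const int prefetchSlot = 1;   // try 2 or 4 - the logic holds for all
        const int chunkCount = 4;
        Func<int, Task<int>> download = async i => { await Task.Delay(10); return i; };

        var taskQueues = new List<Task<int>>();
        int nextToDownload = 0, nextToConsume = 0;
        for (; nextToDownload < prefetchSlot; nextToDownload++)
            taskQueues.Add(download(nextToDownload));   // fill the slots up front

        while (nextToConsume < chunkCount)
        {
            Task<int> chunk = taskQueues[nextToConsume % prefetchSlot];
            if (nextToDownload < chunkCount && nextToConsume > 0)
            {
                // reuse the slot for the next chunk download
                taskQueues[nextToDownload % prefetchSlot] = download(nextToDownload);
                nextToDownload++;
                if (prefetchSlot == 1)
                    chunk = taskQueues[0];   // single slot: take the task just scheduled
            }
            nextToConsume++;
            Console.WriteLine($"consumed chunk {await chunk}");   // prints 0,1,2,3 in order
        }
    }
}
```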
diff --git a/Snowflake.Data/Core/SFChunkDownloaderV2.cs b/Snowflake.Data/Core/SFChunkDownloaderV2.cs
index be9f2b77f..b1400e271 100755
--- a/Snowflake.Data/Core/SFChunkDownloaderV2.cs
+++ b/Snowflake.Data/Core/SFChunkDownloaderV2.cs
@@ -55,8 +55,8 @@ public SFChunkDownloaderV2(int colCount, List<ExecResponseChunk> chunkInfos, stri
             FillDownloads();
         }
 
-        private BlockingCollection<Lazy<Task<IResultChunk>>> _downloadTasks;
-        private ConcurrentQueue<Lazy<Task<IResultChunk>>> _downloadQueue;
+        private BlockingCollection<Lazy<Task<BaseResultChunk>>> _downloadTasks;
+        private ConcurrentQueue<Lazy<Task<BaseResultChunk>>> _downloadQueue;
 
         private void RunDownloads()
         {
@@ -79,11 +79,11 @@ private void RunDownloads()
 
         private void FillDownloads()
         {
-            _downloadTasks = new BlockingCollection<Lazy<Task<IResultChunk>>>();
+            _downloadTasks = new BlockingCollection<Lazy<Task<BaseResultChunk>>>();
 
             foreach (var c in chunks)
             {
-                var t = new Lazy<Task<IResultChunk>>(() => DownloadChunkAsync(new DownloadContextV2()
+                var t = new Lazy<Task<BaseResultChunk>>(() => DownloadChunkAsync(new DownloadContextV2()
                 {
                     chunk = c,
                     chunkIndex = c.ChunkIndex,
@@ -97,18 +97,18 @@ private void FillDownloads()
 
             _downloadTasks.CompleteAdding();
 
-            _downloadQueue = new ConcurrentQueue<Lazy<Task<IResultChunk>>>(_downloadTasks);
+            _downloadQueue = new ConcurrentQueue<Lazy<Task<BaseResultChunk>>>(_downloadTasks);
 
             for (var i = 0; i < prefetchSlot && i < chunks.Count; i++)
                 Task.Run(new Action(RunDownloads));
         }
 
-        public Task<IResultChunk> GetNextChunkAsync()
+        public Task<BaseResultChunk> GetNextChunkAsync()
         {
             if (_downloadTasks.IsAddingCompleted)
             {
-                return Task.FromResult<IResultChunk>(null);
+                return Task.FromResult<BaseResultChunk>(null);
             }
             else
             {
@@ -116,16 +116,14 @@ public Task<IResultChunk> GetNextChunkAsync()
             }
         }
 
-        private async Task<IResultChunk> DownloadChunkAsync(DownloadContextV2 downloadContext)
+        private async Task<BaseResultChunk> DownloadChunkAsync(DownloadContextV2 downloadContext)
         {
             logger.Info($"Start downloading chunk #{downloadContext.chunkIndex+1}");
-            SFResultChunk chunk = downloadContext.chunk;
-
-            chunk.downloadState = DownloadState.IN_PROGRESS;
+            BaseResultChunk chunk = downloadContext.chunk;
 
             S3DownloadRequest downloadRequest = new S3DownloadRequest()
             {
-                Url = new UriBuilder(chunk.url).Uri,
+                Url = new UriBuilder(chunk.Url).Uri,
                 qrmk = downloadContext.qrmk,
                 // s3 download request timeout to one hour
                 RestTimeout = TimeSpan.FromHours(1),
@@ -146,10 +144,9 @@ private async Task<IResultChunk> DownloadChunkAsync(DownloadContextV2 downloadCo
                 }
             }
 
-            parseStreamIntoChunk(stream, chunk);
+            ParseStreamIntoChunk(stream, chunk);
         }
 
-            chunk.downloadState = DownloadState.SUCCESS;
             logger.Info($"Succeed downloading chunk #{downloadContext.chunkIndex+1}");
 
             return chunk;
@@ -165,21 +162,21 @@ private async Task<IResultChunk> DownloadChunkAsync(DownloadContextV2 downloadCo
         /// </summary>
         /// <param name="content"></param>
         /// <param name="resultChunk"></param>
-        private static void parseStreamIntoChunk(Stream content, SFResultChunk resultChunk)
+        private static void ParseStreamIntoChunk(Stream content, BaseResultChunk resultChunk)
         {
             Stream openBracket = new MemoryStream(Encoding.UTF8.GetBytes("["));
             Stream closeBracket = new MemoryStream(Encoding.UTF8.GetBytes("]"));
             Stream concatStream = new ConcatenatedStream(new Stream[3] { openBracket, content, closeBracket});
 
-            IChunkParser parser = ChunkParserFactory.Instance.GetParser(concatStream);
+            IChunkParser parser = ChunkParserFactory.Instance.GetParser(resultChunk.Format, concatStream);
             parser.ParseChunk(resultChunk);
         }
     }
 
     class DownloadContextV2
     {
-        public SFResultChunk chunk { get; set; }
+        public BaseResultChunk chunk { get; set; }
 
         public int chunkIndex { get; set; }
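Reviewer note: both downloader generations keep the same wrapping trick that the deleted XML doc comment in V3 described: a JSON chunk arrives from cloud storage as bare, comma-separated row arrays, not as a complete JSON document, so the driver prepends `[` and appends `]` to make the stream one parseable array. A standalone illustration using plain `MemoryStream`s (the driver's `ConcatenatedStream` avoids the extra copy; this sketch does not):

```csharp
using System;
using System.IO;
using System.Text;

// Demonstrates why the brackets are needed: the raw payload below is not
// valid JSON on its own, but becomes a valid array once wrapped.
class BracketWrapDemo
{
    static void Main()
    {
        var content = new MemoryStream(Encoding.UTF8.GetBytes("[\"1\",\"a\"],[\"2\",\"b\"]"));
        var wrapped = new MemoryStream();
        new MemoryStream(Encoding.UTF8.GetBytes("[")).CopyTo(wrapped);
        content.CopyTo(wrapped);
        new MemoryStream(Encoding.UTF8.GetBytes("]")).CopyTo(wrapped);
        Console.WriteLine(Encoding.UTF8.GetString(wrapped.ToArray()));
        // prints [["1","a"],["2","b"]] - now parseable by any JSON parser
    }
}
```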
diff --git a/Snowflake.Data/Core/SFResultChunk.cs b/Snowflake.Data/Core/SFResultChunk.cs
index 7ff77bebb..05aa27b63 100755
--- a/Snowflake.Data/Core/SFResultChunk.cs
+++ b/Snowflake.Data/Core/SFResultChunk.cs
@@ -6,70 +6,59 @@
 
 namespace Snowflake.Data.Core
 {
-    internal class SFResultChunk : IResultChunk
+    internal class SFResultChunk : BaseResultChunk
     {
-        public string[,] rowSet { get; set; }
+        internal override ResultFormat Format => ResultFormat.JSON;
 
-        public int rowCount { get; set; }
-
-        public int colCount { get; set; }
-
-        public string url { get; set; }
-
-        public DownloadState downloadState { get; set; }
-
-        public int ChunkIndex { get; }
-
-        public readonly object syncPrimitive;
+        private int _currentRowIndex = -1;
 
         public SFResultChunk(string[,] rowSet)
         {
-            this.rowSet = rowSet;
-            this.rowCount = rowSet.GetLength(0);
-            this.colCount = rowSet.GetLength(1);
-            this.downloadState = DownloadState.NOT_STARTED;
+            RowSet = rowSet;
+            RowCount = rowSet.GetLength(0);
+            ColumnCount = rowSet.GetLength(1);
         }
 
-        public SFResultChunk(string url, int rowCount, int colCount, int index)
+        public SFResultChunk(string url, int rowCount, int columnCount, int index)
         {
-            this.rowCount = rowCount;
-            this.colCount = colCount;
-            this.url = url;
+            RowCount = rowCount;
+            ColumnCount = columnCount;
+            Url = url;
             ChunkIndex = index;
-            syncPrimitive = new object();
-            this.downloadState = DownloadState.NOT_STARTED;
         }
 
-        public UTF8Buffer ExtractCell(int rowIndex, int columnIndex)
+        public override UTF8Buffer ExtractCell(int rowIndex, int columnIndex)
+        {
+            _currentRowIndex = rowIndex;
+            return ExtractCell(columnIndex);
+        }
+
+        public override UTF8Buffer ExtractCell(int columnIndex)
         {
             // Convert string to UTF8Buffer. This makes this method a little slower, but this class is not used for large result sets
-            string s = rowSet[rowIndex, columnIndex];
+            string s = RowSet[_currentRowIndex, columnIndex];
             if (s == null)
                 return null;
             byte[] b = Encoding.UTF8.GetBytes(s);
             return new UTF8Buffer(b);
         }
 
-        public void addValue(string val, int rowCount, int colCount)
+        internal override bool Next()
        {
-            rowSet[rowCount, colCount] = val;
+            _currentRowIndex += 1;
+            return _currentRowIndex < RowCount;
         }
 
-        public int GetRowCount()
+        internal override bool Rewind()
         {
-            return rowCount;
+            _currentRowIndex -= 1;
+            return _currentRowIndex >= 0;
         }
 
-        public int GetChunkIndex()
+        internal override void Reset(ExecResponseChunk chunkInfo, int chunkIndex)
         {
-            return ChunkIndex;
+            base.Reset(chunkInfo, chunkIndex);
+            _currentRowIndex = -1;
         }
     }
-
-    public enum DownloadState
-    {
-        NOT_STARTED,
-        IN_PROGRESS,
-        SUCCESS,
-        FAILURE
-    }
 }
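Reviewer note: the rewrite above replaces the old `(rowIndex, columnIndex)` random access with the stateful cursor contract that all `BaseResultChunk` subclasses now share: `Next()`/`Rewind()` move an internal row index, so cell access needs only the column. A condensed, self-contained model of that contract (hypothetical class, not the driver's API):

```csharp
using System;

// Condensed model of the chunk cursor contract used by SFResultChunk,
// SFReusableChunk and ArrowResultChunk in this patch.
class CursorChunk
{
    private readonly string[,] _rows;
    private int _currentRowIndex = -1;   // starts before the first row, as above

    public CursorChunk(string[,] rows) { _rows = rows; }

    public bool Next() => ++_currentRowIndex < _rows.GetLength(0);
    public bool Rewind() => --_currentRowIndex >= 0;
    public string ExtractCell(int columnIndex) => _rows[_currentRowIndex, columnIndex];

    static void Main()
    {
        var chunk = new CursorChunk(new[,] { { "1", "a" }, { "2", "b" } });
        while (chunk.Next())   // mirrors SFResultSet.Next() driving the cursor
            Console.WriteLine($"{chunk.ExtractCell(0)} | {chunk.ExtractCell(1)}");
    }
}
```

This is why `SFResultSet` below can drop its own `_currentChunkRowIdx`/`_currentChunkRowCount` bookkeeping and simply delegate to the chunk.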
diff --git a/Snowflake.Data/Core/SFResultSet.cs b/Snowflake.Data/Core/SFResultSet.cs
index 0ce066d63..2e20453ba 100755
--- a/Snowflake.Data/Core/SFResultSet.cs
+++ b/Snowflake.Data/Core/SFResultSet.cs
@@ -12,28 +12,22 @@ namespace Snowflake.Data.Core
 {
     class SFResultSet : SFBaseResultSet
     {
-        private static readonly SFLogger Logger = SFLoggerFactory.GetLogger<SFResultSet>();
+        private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger<SFResultSet>();
 
-        private int _currentChunkRowIdx;
-
-        private int _currentChunkRowCount;
-
         private readonly int _totalChunkCount;
 
         private readonly IChunkDownloader _chunkDownloader;
 
-        private IResultChunk _currentChunk;
+        private BaseResultChunk _currentChunk;
 
         public SFResultSet(QueryExecResponseData responseData, SFStatement sfStatement, CancellationToken cancellationToken) : base()
         {
             try
             {
                 columnCount = responseData.rowType.Count;
-                _currentChunkRowIdx = -1;
-                _currentChunkRowCount = responseData.rowSet.GetLength(0);
 
                 this.sfStatement = sfStatement;
-                updateSessionStatus(responseData);
+                UpdateSessionStatus(responseData);
 
                 if (responseData.chunks != null)
                 {
@@ -53,7 +47,7 @@ public SFResultSet(QueryExecResponseData responseData, SFStatement sfStatement,
             }
             catch(System.Exception ex)
             {
-                Logger.Error("Result set error queryId="+responseData.queryId, ex);
+                s_logger.Error("Result set error queryId="+responseData.queryId, ex);
                 throw;
             }
         }
@@ -69,9 +63,9 @@ public enum PutGetResponseRowTypeInfo {
             ErrorDetails = 7
         }
 
-        public void initializePutGetRowType(List<ExecResponseRowType> rowType)
+        public void InitializePutGetRowType(List<ExecResponseRowType> rowType)
         {
-            foreach (PutGetResponseRowTypeInfo t in System.Enum.GetValues(typeof(PutGetResponseRowTypeInfo)))
+            foreach (PutGetResponseRowTypeInfo t in System.Enum.GetValues(typeof(PutGetResponseRowTypeInfo)))
             {
                 rowType.Add(new ExecResponseRowType()
                 {
@@ -84,11 +78,9 @@ public void initializePutGetRowType(List<ExecResponseRowType> rowType)
         public SFResultSet(PutGetResponseData responseData, SFStatement sfStatement, CancellationToken cancellationToken) : base()
         {
             responseData.rowType = new List<ExecResponseRowType>();
-            initializePutGetRowType(responseData.rowType);
+            InitializePutGetRowType(responseData.rowType);
 
             columnCount = responseData.rowType.Count;
-            _currentChunkRowIdx = -1;
-            _currentChunkRowCount = responseData.rowSet.GetLength(0);
 
             this.sfStatement = sfStatement;
 
@@ -102,75 +94,54 @@ public SFResultSet(PutGetResponseData responseData, SFStatement sfStatement, Can
             queryId = responseData.queryId;
         }
 
-        internal void resetChunkInfo(IResultChunk nextChunk)
+        internal void ResetChunkInfo(BaseResultChunk nextChunk)
         {
-            Logger.Debug($"Recieved chunk #{nextChunk.GetChunkIndex() + 1} of {_totalChunkCount}");
-            if (_currentChunk is SFResultChunk)
-            {
-                ((SFResultChunk)_currentChunk).rowSet = null;
-            }
+            s_logger.Debug($"Received chunk #{nextChunk.ChunkIndex + 1} of {_totalChunkCount}");
+            _currentChunk.RowSet = null;
             _currentChunk = nextChunk;
-            _currentChunkRowIdx = 0;
-            _currentChunkRowCount = _currentChunk.GetRowCount();
         }
 
         internal override async Task<bool> NextAsync()
         {
-            if (isClosed)
-            {
-                throw new SnowflakeDbException(SFError.DATA_READER_ALREADY_CLOSED);
-            }
+            ThrowIfClosed();
 
-            _currentChunkRowIdx++;
-            if (_currentChunkRowIdx < _currentChunkRowCount)
-            {
+            if (_currentChunk.Next())
                 return true;
-            }
 
             if (_chunkDownloader != null)
             {
                 // GetNextChunk could be blocked if download result is not done yet.
                 // So put this piece of code in a seperate task
-                Logger.Info("Get next chunk from chunk downloader");
-                IResultChunk nextChunk = await _chunkDownloader.GetNextChunkAsync().ConfigureAwait(false);
+                s_logger.Debug("Get next chunk from chunk downloader");
+                BaseResultChunk nextChunk = await _chunkDownloader.GetNextChunkAsync().ConfigureAwait(false);
                 if (nextChunk != null)
                 {
-                    resetChunkInfo(nextChunk);
-                    return true;
-                }
-                else
-                {
-                    return false;
+                    ResetChunkInfo(nextChunk);
+                    return _currentChunk.Next();
                 }
             }
 
-            return false;
+            return false;
         }
 
         internal override bool Next()
         {
-            if (isClosed)
-            {
-                throw new SnowflakeDbException(SFError.DATA_READER_ALREADY_CLOSED);
-            }
+            ThrowIfClosed();
 
-            _currentChunkRowIdx++;
-            if (_currentChunkRowIdx < _currentChunkRowCount)
-            {
+            if (_currentChunk.Next())
                 return true;
-            }
 
             if (_chunkDownloader != null)
             {
-                Logger.Info("Get next chunk from chunk downloader");
-                IResultChunk nextChunk = Task.Run(async() => await (_chunkDownloader.GetNextChunkAsync()).ConfigureAwait(false)).Result;
+                s_logger.Debug("Get next chunk from chunk downloader");
+                BaseResultChunk nextChunk = Task.Run(async() => await (_chunkDownloader.GetNextChunkAsync()).ConfigureAwait(false)).Result;
                 if (nextChunk != null)
                 {
-                    resetChunkInfo(nextChunk);
-                    return true;
+                    ResetChunkInfo(nextChunk);
+                    return _currentChunk.Next();
                 }
             }
 
-            return false;
+            return false;
         }
 
         internal override bool NextResult()
@@ -185,12 +156,9 @@ internal override async Task<bool> NextResultAsync(CancellationToken cancellatio
 
         internal override bool HasRows()
         {
-            if (isClosed)
-            {
-                return false;
-            }
+            ThrowIfClosed();
 
-            return _currentChunkRowCount > 0 || _totalChunkCount > 0;
+            return _currentChunk.RowCount > 0 || _totalChunkCount > 0;
         }
 
         /// <summary>
@@ -199,39 +167,24 @@ internal override bool HasRows()
         /// <returns>True if it works, false otherwise.</returns>
         internal override bool Rewind()
         {
-            if (isClosed)
-            {
-                throw new SnowflakeDbException(SFError.DATA_READER_ALREADY_CLOSED);
-            }
-
-            if (_currentChunkRowIdx >= 0)
-            {
-                _currentChunkRowIdx--;
-                if (_currentChunkRowIdx >= _currentChunkRowCount)
-                {
-                    return true;
-                }
-            }
+            ThrowIfClosed();
 
-            return false;
+            return _currentChunk.Rewind();
         }
 
         internal override UTF8Buffer getObjectInternal(int columnIndex)
         {
-            if (isClosed)
-            {
-                throw new SnowflakeDbException(SFError.DATA_READER_ALREADY_CLOSED);
-            }
+            ThrowIfClosed();
 
             if (columnIndex < 0 || columnIndex >= columnCount)
             {
                 throw new SnowflakeDbException(SFError.COLUMN_INDEX_OUT_OF_BOUND, columnIndex);
             }
 
-            return _currentChunk.ExtractCell(_currentChunkRowIdx, columnIndex);
+            return _currentChunk.ExtractCell(columnIndex);
        }
 
-        private void updateSessionStatus(QueryExecResponseData responseData)
+        private void UpdateSessionStatus(QueryExecResponseData responseData)
        {
             SFSession session = this.sfStatement.SfSession;
             session.UpdateDatabaseAndSchema(responseData.finalDatabaseName, responseData.finalSchemaName);
diff --git a/Snowflake.Data/Core/SFReusableChunk.cs b/Snowflake.Data/Core/SFReusableChunk.cs
index c1362aba3..8e36907a0 100755
--- a/Snowflake.Data/Core/SFReusableChunk.cs
+++ b/Snowflake.Data/Core/SFReusableChunk.cs
@@ -8,51 +8,41 @@
 
 namespace Snowflake.Data.Core
 {
-    class SFReusableChunk : IResultChunk
+    class SFReusableChunk : BaseResultChunk
     {
-
-        public int RowCount { get; set; }
-
-        public int ColCount { get; set; }
-
-        public string Url { get; set; }
-
-        public int chunkIndexToDownload { get; set; }
-
+        internal override ResultFormat Format => ResultFormat.JSON;
+
         private readonly BlockResultData data;
 
-        internal SFReusableChunk(int colCount)
+        private int _currentRowIndex = -1;
+
+        internal SFReusableChunk(int columnCount)
         {
-            ColCount = colCount;
+            ColumnCount = columnCount;
             data = new BlockResultData();
         }
 
-        internal void Reset(ExecResponseChunk chunkInfo, int chunkIndex)
+        internal override void Reset(ExecResponseChunk chunkInfo, int chunkIndex)
         {
-            this.RowCount = chunkInfo.rowCount;
-            this.Url = chunkInfo.url;
-            this.chunkIndexToDownload = chunkIndex;
-            data.Reset(this.RowCount, this.ColCount, chunkInfo.uncompressedSize);
+            base.Reset(chunkInfo, chunkIndex);
+            _currentRowIndex = -1;
+            data.Reset(RowCount, ColumnCount, chunkInfo.uncompressedSize);
         }
 
-        internal void ResetForRetry()
+        internal override void ResetForRetry()
         {
             data.ResetForRetry();
         }
 
-        public int GetRowCount()
+        public override UTF8Buffer ExtractCell(int rowIndex, int columnIndex)
         {
-            return RowCount;
+            _currentRowIndex = rowIndex;
+            return ExtractCell(columnIndex);
         }
 
-        public int GetChunkIndex()
+        public override UTF8Buffer ExtractCell(int columnIndex)
         {
-            return chunkIndexToDownload;
-        }
-
-        public UTF8Buffer ExtractCell(int rowIndex, int columnIndex)
-        {
-            return data.get(rowIndex * ColCount + columnIndex);
+            return data.get(_currentRowIndex * ColumnCount + columnIndex);
         }
 
         public void AddCell(string val)
@@ -66,6 +56,18 @@ public void AddCell(byte[] bytes, int length)
             data.add(bytes, length);
         }
 
+        internal override bool Next()
+        {
+            _currentRowIndex += 1;
+            return _currentRowIndex < RowCount;
+        }
+
+        internal override bool Rewind()
+        {
+            _currentRowIndex -= 1;
+            return _currentRowIndex >= 0;
+        }
+
         private class BlockResultData
         {
             private static readonly int NULL_VALUE = -100;