Reviewer notes (free text ahead of the first "diff --git" header; not part of any hunk).

What this patch does, per the hunks below:
  * FlightDataStream: Write() now stores the application metadata in a field and delegates to a
    new WriteRecordBatchAsync override. A new WriteDictionariesAsync override copies whatever the
    base writer produced in BaseStream into a FlightData body, sends it, then resets the stream and
    the FlightData. HasWrittenDictionaryBatch is cleared after every record batch so dictionaries
    are written again with the next batch.
  * FlightMessageSerializer.DecodeSchema takes the DictionaryMemo by ref instead of using a local null.
  * RecordBatchReaderImplementation: ReadNextRecordBatchAsync loops while messages are
    DictionaryBatch, returns on the first RecordBatch, and throws InvalidOperationException if the
    stream ends right after a dictionary.
  * ArrowStreamWriter: HasWrittenDictionaryBatch becomes protected, WriteDictionaries and
    WriteDictionariesAsync become virtual, DictionaryCollector.Collect receives the Schema explicitly.
  * FlightTests: the test batch gains a dictionary-encoded "dict" column; whitespace fixes.

Review change applied: the awaited WriteRecordBatchInternalAsync call in the new
WriteRecordBatchAsync override now uses ConfigureAwait(false), like every other await in this patch.

NOTE(review): generic type arguments look stripped in this copy (e.g. "IAsyncStreamWriter",
"ReadOnlyMemory", "List ... = new List()"); they are left as found - confirm against upstream.
NOTE(review): _currentAppMetadata is assigned in Write() and never cleared in the visible hunks, so a
direct call to the public WriteRecordBatchAsync would attach the previous call's metadata - confirm intended.

diff --git a/csharp/src/Apache.Arrow.Flight/Internal/FlightDataStream.cs b/csharp/src/Apache.Arrow.Flight/Internal/FlightDataStream.cs
index 72c1551be2917..e755b4a26f621 100644
--- a/csharp/src/Apache.Arrow.Flight/Internal/FlightDataStream.cs
+++ b/csharp/src/Apache.Arrow.Flight/Internal/FlightDataStream.cs
@@ -14,13 +14,10 @@
 // limitations under the License.
 
 using System;
-using System.Collections.Generic;
 using System.IO;
-using System.Text;
 using System.Threading;
 using System.Threading.Tasks;
 using Apache.Arrow.Flatbuf;
-using Apache.Arrow.Flight.Protocol;
 using Apache.Arrow.Ipc;
 using Google.FlatBuffers;
 using Google.Protobuf;
@@ -36,6 +33,7 @@ internal class FlightDataStream : ArrowStreamWriter
         private readonly FlightDescriptor _flightDescriptor;
         private readonly IAsyncStreamWriter _clientStreamWriter;
         private Protocol.FlightData _currentFlightData;
+        private ByteString _currentAppMetadata;
 
         public FlightDataStream(IAsyncStreamWriter clientStreamWriter, FlightDescriptor flightDescriptor, Schema schema)
             : base(new MemoryStream(), schema)
@@ -66,29 +64,66 @@ private void ResetStream()
             this.BaseStream.SetLength(0);
         }
 
+        private void ResetFlightData()
+        {
+            _currentFlightData = new Protocol.FlightData();
+        }
+
+        private void AddMetadata()
+        {
+            if (_currentAppMetadata != null)
+            {
+                _currentFlightData.AppMetadata = _currentAppMetadata;
+            }
+        }
+
+        private async Task SetFlightDataBodyFromBaseStreamAsync(CancellationToken cancellationToken)
+        {
+            BaseStream.Position = 0;
+            var body = await ByteString.FromStreamAsync(BaseStream, cancellationToken).ConfigureAwait(false);
+            _currentFlightData.DataBody = body;
+        }
+
+        private async Task WriteFlightDataAsync()
+        {
+            await _clientStreamWriter.WriteAsync(_currentFlightData).ConfigureAwait(false);
+        }
+
         public async Task Write(RecordBatch recordBatch, ByteString applicationMetadata)
         {
+            _currentAppMetadata = applicationMetadata;
             if (!HasWrittenSchema)
             {
                 await SendSchema().ConfigureAwait(false);
             }
             ResetStream();
+            ResetFlightData();
 
-            _currentFlightData = new Protocol.FlightData();
+            await WriteRecordBatchAsync(recordBatch).ConfigureAwait(false);
+        }
 
-            if(applicationMetadata != null)
-            {
-                _currentFlightData.AppMetadata = applicationMetadata;
-            }
+        public override async Task WriteRecordBatchAsync(RecordBatch recordBatch, CancellationToken cancellationToken = default)
+        {
+            await WriteRecordBatchInternalAsync(recordBatch, cancellationToken).ConfigureAwait(false);
 
-            await WriteRecordBatchInternalAsync(recordBatch).ConfigureAwait(false);
+            // Consume the MemoryStream and write to the flight stream
+            await SetFlightDataBodyFromBaseStreamAsync(cancellationToken).ConfigureAwait(false);
+            AddMetadata();
+            await WriteFlightDataAsync().ConfigureAwait(false);
 
-            //Reset stream position
-            this.BaseStream.Position = 0;
-            var bodyData = await ByteString.FromStreamAsync(this.BaseStream).ConfigureAwait(false);
+            HasWrittenDictionaryBatch = false; // force the dictionary to be sent again with the next batch
+        }
 
-            _currentFlightData.DataBody = bodyData;
-            await _clientStreamWriter.WriteAsync(_currentFlightData).ConfigureAwait(false);
+        private protected override async Task WriteDictionariesAsync(DictionaryMemo dictionaryMemo, CancellationToken cancellationToken)
+        {
+            await base.WriteDictionariesAsync(dictionaryMemo, cancellationToken).ConfigureAwait(false);
+
+            // Consume the MemoryStream and write to the flight stream
+            await SetFlightDataBodyFromBaseStreamAsync(cancellationToken).ConfigureAwait(false);
+            await WriteFlightDataAsync().ConfigureAwait(false);
+            // Reset the stream for the next dictionary or record batch
+            ResetStream();
+            ResetFlightData();
         }
 
         private protected override ValueTask WriteMessageAsync(MessageHeader headerType, Offset headerOffset, int bodyLength, CancellationToken cancellationToken)
diff --git a/csharp/src/Apache.Arrow.Flight/Internal/FlightMessageSerializer.cs b/csharp/src/Apache.Arrow.Flight/Internal/FlightMessageSerializer.cs
index 9df28b5033c06..67afeaf7b034b 100644
--- a/csharp/src/Apache.Arrow.Flight/Internal/FlightMessageSerializer.cs
+++ b/csharp/src/Apache.Arrow.Flight/Internal/FlightMessageSerializer.cs
@@ -15,9 +15,7 @@
 
 using System;
 using System.Buffers.Binary;
-using System.Collections.Generic;
 using System.IO;
-using System.Text;
 using Apache.Arrow.Ipc;
 using Google.FlatBuffers;
 
@@ -50,10 +48,8 @@ public static Schema DecodeSchema(ReadOnlyMemory buffer)
             return schema;
         }
 
-        internal static Schema DecodeSchema(ByteBuffer schemaBuffer)
+        internal static Schema DecodeSchema(ByteBuffer schemaBuffer, ref DictionaryMemo dictionaryMemo)
         {
-            //DictionaryBatch not supported for now
-            DictionaryMemo dictionaryMemo = null;
             var schema = MessageSerializer.GetSchema(ArrowReaderImplementation.ReadMessage(schemaBuffer), ref dictionaryMemo);
             return schema;
         }
diff --git a/csharp/src/Apache.Arrow.Flight/Internal/RecordBatchReaderImplementation.cs b/csharp/src/Apache.Arrow.Flight/Internal/RecordBatchReaderImplementation.cs
index be844ea58e404..44a025364b601 100644
--- a/csharp/src/Apache.Arrow.Flight/Internal/RecordBatchReaderImplementation.cs
+++ b/csharp/src/Apache.Arrow.Flight/Internal/RecordBatchReaderImplementation.cs
@@ -87,7 +87,7 @@ public async ValueTask ReadSchema()
             switch (message.HeaderType)
             {
                 case MessageHeader.Schema:
-                    Schema = FlightMessageSerializer.DecodeSchema(message.ByteBuffer);
+                    Schema = FlightMessageSerializer.DecodeSchema(message.ByteBuffer, ref _dictionaryMemo);
                     break;
                 default:
                     throw new Exception($"Expected schema as the first message, but got: {message.HeaderType.ToString()}");
@@ -103,8 +103,10 @@ public override async ValueTask ReadNextRecordBatchAsync(Cancellati
             {
                 await ReadSchema().ConfigureAwait(false);
             }
-            var moveNextResult = await _flightDataStream.MoveNext().ConfigureAwait(false);
-            if (moveNextResult)
+
+            // Keep reading dictionary batches until we get a record batch
+            var keepGoing = await _flightDataStream.MoveNext().ConfigureAwait(false);
+            while (keepGoing)
             {
                 //AppMetadata will never be null, but length 0 if empty
                 //Those are skipped
@@ -121,8 +123,17 @@ public override async ValueTask ReadNextRecordBatchAsync(Cancellati
                     case MessageHeader.RecordBatch:
                         var body = _flightDataStream.Current.DataBody.Memory;
                         return CreateArrowObjectFromMessage(message, CreateByteBuffer(body.Slice(0, (int)message.BodyLength)), null);
+                    case MessageHeader.DictionaryBatch:
+                        var dictionaryBody = _flightDataStream.Current.DataBody.Memory;
+                        CreateArrowObjectFromMessage(message, CreateByteBuffer(dictionaryBody.Slice(0, (int)message.BodyLength)), null);
+                        keepGoing = await _flightDataStream.MoveNext().ConfigureAwait(false);
+                        if (!keepGoing)
+                        {
+                            throw new InvalidOperationException("Flight Data Stream ended after reading dictionaries");
+                        }
+                        break;
                     default:
-                        throw new NotImplementedException();
+                        throw new NotImplementedException($"Message type {message.HeaderType} is not implemented.");
                 }
             }
             return null;
diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs b/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs
index 5f490019b2133..c982fc3529a07 100644
--- a/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs
+++ b/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs
@@ -232,7 +232,7 @@ public void Visit(IArrowArray array)
 
         protected bool HasWrittenSchema { get; set; }
 
-        private bool HasWrittenDictionaryBatch { get; set; }
+        protected bool HasWrittenDictionaryBatch { get; set; }
 
         private bool HasWrittenStart { get; set; }
 
@@ -323,7 +323,7 @@ private protected void WriteRecordBatchInternal(RecordBatch recordBatch)
 
             if (!HasWrittenDictionaryBatch)
             {
-                DictionaryCollector.Collect(recordBatch, ref _dictionaryMemo);
+                DictionaryCollector.Collect(Schema, recordBatch, ref _dictionaryMemo);
                 WriteDictionaries(_dictionaryMemo);
                 HasWrittenDictionaryBatch = true;
             }
@@ -362,7 +362,7 @@ private protected async Task WriteRecordBatchInternalAsync(RecordBatch recordBat
 
             if (!HasWrittenDictionaryBatch)
             {
-                DictionaryCollector.Collect(recordBatch, ref _dictionaryMemo);
+                DictionaryCollector.Collect(Schema, recordBatch, ref _dictionaryMemo);
                 await WriteDictionariesAsync(_dictionaryMemo, cancellationToken).ConfigureAwait(false);
                 HasWrittenDictionaryBatch = true;
             }
@@ -505,7 +505,7 @@ private protected virtual void FinishedWritingDictionary(long bodyLength, long m
         {
         }
 
-        private protected void WriteDictionaries(DictionaryMemo dictionaryMemo)
+        private protected virtual void WriteDictionaries(DictionaryMemo dictionaryMemo)
         {
             int fieldCount = dictionaryMemo?.DictionaryCount ?? 0;
             for (int i = 0; i < fieldCount; i++)
@@ -529,7 +529,7 @@ private protected void WriteDictionary(long id, IArrowType valueType, IArrowArra
             FinishedWritingDictionary(bufferLength, metadataLength);
         }
 
-        private protected async Task WriteDictionariesAsync(DictionaryMemo dictionaryMemo, CancellationToken cancellationToken)
+        private protected virtual async Task WriteDictionariesAsync(DictionaryMemo dictionaryMemo, CancellationToken cancellationToken)
         {
             int fieldCount = dictionaryMemo?.DictionaryCount ?? 0;
             for (int i = 0; i < fieldCount; i++)
@@ -960,9 +960,8 @@ public virtual void Dispose()
 
     internal static class DictionaryCollector
     {
-        internal static void Collect(RecordBatch recordBatch, ref DictionaryMemo dictionaryMemo)
+        internal static void Collect(Schema schema, RecordBatch recordBatch, ref DictionaryMemo dictionaryMemo)
         {
-            Schema schema = recordBatch.Schema;
             for (int i = 0; i < schema.FieldsList.Count; i++)
             {
                 Field field = schema.GetFieldByIndex(i);
diff --git a/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs b/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs
index 267fe4e4b606d..33a9d3739b080 100644
--- a/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs
+++ b/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs
@@ -20,6 +20,7 @@
 using Apache.Arrow.Flight.Client;
 using Apache.Arrow.Flight.TestWeb;
 using Apache.Arrow.Tests;
+using Apache.Arrow.Types;
 using Google.Protobuf;
 using Grpc.Core.Utils;
 using Xunit;
@@ -52,6 +53,10 @@ private RecordBatch CreateTestBatch(int startValue, int length)
                 builder.Append(startValue + i);
             }
             batchBuilder.Append("test", true, builder.Build());
+            var keys = new UInt16Array.Builder().AppendRange(Enumerable.Range(startValue, length).Select(i => (ushort)i)).Build();
+            var dictionary = new StringArray.Builder().AppendRange(Enumerable.Range(startValue, length).Select(i => i.ToString())).Build();
+            var dictArray = new DictionaryArray(new DictionaryType(UInt16Type.Default, StringType.Default, false), keys, dictionary);
+            batchBuilder.Append("dict", true, dictArray);
             return batchBuilder.Build();
         }
 
@@ -70,7 +75,7 @@ private FlightInfo GivenStoreBatches(FlightDescriptor flightDescriptor, params R
 
             var flightHolder = new FlightHolder(flightDescriptor, initialBatch.RecordBatch.Schema, _testWebFactory.GetAddress());
 
-            foreach(var batch in batches)
+            foreach (var batch in batches)
             {
                 flightHolder.AddBatch(batch);
             }
@@ -187,8 +192,8 @@ public async Task TestGetFlightMetadata()
 
             var getStream = _flightClient.GetStream(endpoint.Ticket);
 
-            List actualMetadata = new List();
-            while(await getStream.ResponseStream.MoveNext(default))
+            List actualMetadata = new List();
+            while (await getStream.ResponseStream.MoveNext(default))
             {
                 actualMetadata.AddRange(getStream.ResponseStream.ApplicationMetadata);
             }
@@ -277,7 +282,7 @@ public async Task TestListFlights()
 
             var actualFlights = await listFlightStream.ResponseStream.ToListAsync();
 
-            for(int i = 0; i < expectedFlightInfo.Count; i++)
+            for (int i = 0; i < expectedFlightInfo.Count; i++)
             {
                 FlightInfoComparer.Compare(expectedFlightInfo[i], actualFlights[i]);
             }
@@ -329,7 +334,7 @@ public async Task TestGetBatchesWithAsyncEnumerable()
 
             List resultList = new List();
 
-            await foreach(var recordBatch in getStream.ResponseStream)
+            await foreach (var recordBatch in getStream.ResponseStream)
             {
                 resultList.Add(recordBatch);
             }