Skip to content

Commit

Permalink
Add support for dictionaries to Flight implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
CurtHagenlocher committed Dec 12, 2023
1 parent 595b37c commit 6512096
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 35 deletions.
63 changes: 49 additions & 14 deletions csharp/src/Apache.Arrow.Flight/Internal/FlightDataStream.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,10 @@
// limitations under the License.

using System;
using System.Collections.Generic;
using System.IO;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Apache.Arrow.Flatbuf;
using Apache.Arrow.Flight.Protocol;
using Apache.Arrow.Ipc;
using Google.FlatBuffers;
using Google.Protobuf;
Expand All @@ -36,6 +33,7 @@ internal class FlightDataStream : ArrowStreamWriter
private readonly FlightDescriptor _flightDescriptor;
private readonly IAsyncStreamWriter<Protocol.FlightData> _clientStreamWriter;
private Protocol.FlightData _currentFlightData;
private ByteString _currentAppMetadata;

public FlightDataStream(IAsyncStreamWriter<Protocol.FlightData> clientStreamWriter, FlightDescriptor flightDescriptor, Schema schema)
: base(new MemoryStream(), schema)
Expand Down Expand Up @@ -66,29 +64,66 @@ private void ResetStream()
this.BaseStream.SetLength(0);
}

private void ResetFlightData()
{
_currentFlightData = new Protocol.FlightData();
}

private void AddMetadata()
{
if (_currentAppMetadata != null)
{
_currentFlightData.AppMetadata = _currentAppMetadata;
}
}

private async Task SetFlightDataBodyFromBaseStreamAsync(CancellationToken cancellationToken)
{
BaseStream.Position = 0;
var body = await ByteString.FromStreamAsync(BaseStream, cancellationToken).ConfigureAwait(false);
_currentFlightData.DataBody = body;
}

private async Task WriteFlightDataAsync()
{
await _clientStreamWriter.WriteAsync(_currentFlightData).ConfigureAwait(false);
}

public async Task Write(RecordBatch recordBatch, ByteString applicationMetadata)
{
_currentAppMetadata = applicationMetadata;
if (!HasWrittenSchema)
{
await SendSchema().ConfigureAwait(false);
}
ResetStream();
ResetFlightData();

_currentFlightData = new Protocol.FlightData();
await WriteRecordBatchAsync(recordBatch).ConfigureAwait(false);
}

if(applicationMetadata != null)
{
_currentFlightData.AppMetadata = applicationMetadata;
}
public override async Task WriteRecordBatchAsync(RecordBatch recordBatch, CancellationToken cancellationToken = default)
{
await WriteRecordBatchInternalAsync(recordBatch, cancellationToken);

await WriteRecordBatchInternalAsync(recordBatch).ConfigureAwait(false);
// Consume the MemoryStream and write to the flight stream
await SetFlightDataBodyFromBaseStreamAsync(cancellationToken).ConfigureAwait(false);
AddMetadata();
await WriteFlightDataAsync().ConfigureAwait(false);

//Reset stream position
this.BaseStream.Position = 0;
var bodyData = await ByteString.FromStreamAsync(this.BaseStream).ConfigureAwait(false);
HasWrittenDictionaryBatch = false; // force the dictionary to be sent again with the next batch
}

_currentFlightData.DataBody = bodyData;
await _clientStreamWriter.WriteAsync(_currentFlightData).ConfigureAwait(false);
private protected override async Task WriteDictionariesAsync(DictionaryMemo dictionaryMemo, CancellationToken cancellationToken)
{
await base.WriteDictionariesAsync(dictionaryMemo, cancellationToken).ConfigureAwait(false);

// Consume the MemoryStream and write to the flight stream
await SetFlightDataBodyFromBaseStreamAsync(cancellationToken).ConfigureAwait(false);
await WriteFlightDataAsync().ConfigureAwait(false);
// Reset the stream for the next dictionary or record batch
ResetStream();
ResetFlightData();
}

private protected override ValueTask<long> WriteMessageAsync<T>(MessageHeader headerType, Offset<T> headerOffset, int bodyLength, CancellationToken cancellationToken)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@

using System;
using System.Buffers.Binary;
using System.Collections.Generic;
using System.IO;
using System.Text;
using Apache.Arrow.Ipc;
using Google.FlatBuffers;

Expand Down Expand Up @@ -50,10 +48,8 @@ public static Schema DecodeSchema(ReadOnlyMemory<byte> buffer)
return schema;
}

internal static Schema DecodeSchema(ByteBuffer schemaBuffer)
internal static Schema DecodeSchema(ByteBuffer schemaBuffer, ref DictionaryMemo dictionaryMemo)
{
//DictionaryBatch not supported for now
DictionaryMemo dictionaryMemo = null;
var schema = MessageSerializer.GetSchema(ArrowReaderImplementation.ReadMessage<Flatbuf.Schema>(schemaBuffer), ref dictionaryMemo);
return schema;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ public async ValueTask<Schema> ReadSchema()
switch (message.HeaderType)
{
case MessageHeader.Schema:
Schema = FlightMessageSerializer.DecodeSchema(message.ByteBuffer);
Schema = FlightMessageSerializer.DecodeSchema(message.ByteBuffer, ref _dictionaryMemo);
break;
default:
throw new Exception($"Expected schema as the first message, but got: {message.HeaderType.ToString()}");
Expand All @@ -103,8 +103,10 @@ public override async ValueTask<RecordBatch> ReadNextRecordBatchAsync(Cancellati
{
await ReadSchema().ConfigureAwait(false);
}
var moveNextResult = await _flightDataStream.MoveNext().ConfigureAwait(false);
if (moveNextResult)

// Keep reading dictionary batches until we get a record batch
var keepGoing = await _flightDataStream.MoveNext().ConfigureAwait(false);
while (keepGoing)
{
//AppMetadata will never be null, but length 0 if empty
//Those are skipped
Expand All @@ -121,8 +123,17 @@ public override async ValueTask<RecordBatch> ReadNextRecordBatchAsync(Cancellati
case MessageHeader.RecordBatch:
var body = _flightDataStream.Current.DataBody.Memory;
return CreateArrowObjectFromMessage(message, CreateByteBuffer(body.Slice(0, (int)message.BodyLength)), null);
case MessageHeader.DictionaryBatch:
var dictionaryBody = _flightDataStream.Current.DataBody.Memory;
CreateArrowObjectFromMessage(message, CreateByteBuffer(dictionaryBody.Slice(0, (int)message.BodyLength)), null);
keepGoing = await _flightDataStream.MoveNext().ConfigureAwait(false);
if (!keepGoing)
{
throw new InvalidOperationException("Flight Data Stream ended after reading dictionaries");
}
break;
default:
throw new NotImplementedException();
throw new NotImplementedException($"Message type {message.HeaderType} is not implemented.");
}
}
return null;
Expand Down
13 changes: 6 additions & 7 deletions csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ public void Visit(IArrowArray array)

protected bool HasWrittenSchema { get; set; }

private bool HasWrittenDictionaryBatch { get; set; }
protected bool HasWrittenDictionaryBatch { get; set; }

private bool HasWrittenStart { get; set; }

Expand Down Expand Up @@ -323,7 +323,7 @@ private protected void WriteRecordBatchInternal(RecordBatch recordBatch)

if (!HasWrittenDictionaryBatch)
{
DictionaryCollector.Collect(recordBatch, ref _dictionaryMemo);
DictionaryCollector.Collect(Schema, recordBatch, ref _dictionaryMemo);
WriteDictionaries(_dictionaryMemo);
HasWrittenDictionaryBatch = true;
}
Expand Down Expand Up @@ -362,7 +362,7 @@ private protected async Task WriteRecordBatchInternalAsync(RecordBatch recordBat

if (!HasWrittenDictionaryBatch)
{
DictionaryCollector.Collect(recordBatch, ref _dictionaryMemo);
DictionaryCollector.Collect(Schema, recordBatch, ref _dictionaryMemo);
await WriteDictionariesAsync(_dictionaryMemo, cancellationToken).ConfigureAwait(false);
HasWrittenDictionaryBatch = true;
}
Expand Down Expand Up @@ -505,7 +505,7 @@ private protected virtual void FinishedWritingDictionary(long bodyLength, long m
{
}

private protected void WriteDictionaries(DictionaryMemo dictionaryMemo)
private protected virtual void WriteDictionaries(DictionaryMemo dictionaryMemo)
{
int fieldCount = dictionaryMemo?.DictionaryCount ?? 0;
for (int i = 0; i < fieldCount; i++)
Expand All @@ -529,7 +529,7 @@ private protected void WriteDictionary(long id, IArrowType valueType, IArrowArra
FinishedWritingDictionary(bufferLength, metadataLength);
}

private protected async Task WriteDictionariesAsync(DictionaryMemo dictionaryMemo, CancellationToken cancellationToken)
private protected virtual async Task WriteDictionariesAsync(DictionaryMemo dictionaryMemo, CancellationToken cancellationToken)
{
int fieldCount = dictionaryMemo?.DictionaryCount ?? 0;
for (int i = 0; i < fieldCount; i++)
Expand Down Expand Up @@ -960,9 +960,8 @@ public virtual void Dispose()

internal static class DictionaryCollector
{
internal static void Collect(RecordBatch recordBatch, ref DictionaryMemo dictionaryMemo)
internal static void Collect(Schema schema, RecordBatch recordBatch, ref DictionaryMemo dictionaryMemo)
{
Schema schema = recordBatch.Schema;
for (int i = 0; i < schema.FieldsList.Count; i++)
{
Field field = schema.GetFieldByIndex(i);
Expand Down
15 changes: 10 additions & 5 deletions csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
using Apache.Arrow.Flight.Client;
using Apache.Arrow.Flight.TestWeb;
using Apache.Arrow.Tests;
using Apache.Arrow.Types;
using Google.Protobuf;
using Grpc.Core.Utils;
using Xunit;
Expand Down Expand Up @@ -52,6 +53,10 @@ private RecordBatch CreateTestBatch(int startValue, int length)
builder.Append(startValue + i);
}
batchBuilder.Append("test", true, builder.Build());
var keys = new UInt16Array.Builder().AppendRange(Enumerable.Range(startValue, length).Select(i => (ushort)i)).Build();
var dictionary = new StringArray.Builder().AppendRange(Enumerable.Range(startValue, length).Select(i => i.ToString())).Build();
var dictArray = new DictionaryArray(new DictionaryType(UInt16Type.Default, StringType.Default, false), keys, dictionary);
batchBuilder.Append("dict", true, dictArray);
return batchBuilder.Build();
}

Expand All @@ -70,7 +75,7 @@ private FlightInfo GivenStoreBatches(FlightDescriptor flightDescriptor, params R

var flightHolder = new FlightHolder(flightDescriptor, initialBatch.RecordBatch.Schema, _testWebFactory.GetAddress());

foreach(var batch in batches)
foreach (var batch in batches)
{
flightHolder.AddBatch(batch);
}
Expand Down Expand Up @@ -187,8 +192,8 @@ public async Task TestGetFlightMetadata()

var getStream = _flightClient.GetStream(endpoint.Ticket);

List<ByteString> actualMetadata = new List<ByteString>();
while(await getStream.ResponseStream.MoveNext(default))
List<ByteString> actualMetadata = new List<ByteString>();
while (await getStream.ResponseStream.MoveNext(default))
{
actualMetadata.AddRange(getStream.ResponseStream.ApplicationMetadata);
}
Expand Down Expand Up @@ -277,7 +282,7 @@ public async Task TestListFlights()

var actualFlights = await listFlightStream.ResponseStream.ToListAsync();

for(int i = 0; i < expectedFlightInfo.Count; i++)
for (int i = 0; i < expectedFlightInfo.Count; i++)
{
FlightInfoComparer.Compare(expectedFlightInfo[i], actualFlights[i]);
}
Expand Down Expand Up @@ -329,7 +334,7 @@ public async Task TestGetBatchesWithAsyncEnumerable()


List<RecordBatch> resultList = new List<RecordBatch>();
await foreach(var recordBatch in getStream.ResponseStream)
await foreach (var recordBatch in getStream.ResponseStream)
{
resultList.Add(recordBatch);
}
Expand Down

0 comments on commit 6512096

Please sign in to comment.