From cc771a013362248269b75e054c2fed9c3d0f352a Mon Sep 17 00:00:00 2001 From: Curt Hagenlocher Date: Mon, 25 Mar 2024 07:59:38 -0700 Subject: [PATCH] GH-40634: [C#] ArrowStreamReader should not be null (#40765) ### What changes are included in this PR? Small refactoring in the IPC reader implementation classes of how the schema is read in order to support getting the schema asynchronously through ArrowStreamReader and avoiding the case where ArrowStreamReader.Schema returns null because no record batches have yet been read. ### Are these changes tested? Yes. ### Are there any user-facing changes? A new method ArrowStreamReader.GetSchema has been added to allow the schema to be gotten asynchronously. Closes #40634 * GitHub Issue: #40634 Authored-by: Curt Hagenlocher Signed-off-by: Curt Hagenlocher --- .../FlightRecordBatchStreamReader.cs | 4 +-- .../RecordBatchReaderImplementation.cs | 27 ++++++++++++++----- .../Ipc/ArrowFileReaderImplementation.cs | 6 ++--- .../Ipc/ArrowMemoryReaderImplementation.cs | 11 ++++++-- .../Ipc/ArrowReaderImplementation.cs | 19 +++++++++++-- .../src/Apache.Arrow/Ipc/ArrowStreamReader.cs | 12 +++++++++ .../Ipc/ArrowStreamReaderImplementation.cs | 8 +++--- .../Apache.Arrow.Tests/ArrowReaderVerifier.cs | 3 +++ .../ArrowStreamReaderTests.cs | 2 ++ 9 files changed, 72 insertions(+), 20 deletions(-) diff --git a/csharp/src/Apache.Arrow.Flight/FlightRecordBatchStreamReader.cs b/csharp/src/Apache.Arrow.Flight/FlightRecordBatchStreamReader.cs index d21fb25f5c946..7400ec15e54d6 100644 --- a/csharp/src/Apache.Arrow.Flight/FlightRecordBatchStreamReader.cs +++ b/csharp/src/Apache.Arrow.Flight/FlightRecordBatchStreamReader.cs @@ -45,12 +45,12 @@ private protected FlightRecordBatchStreamReader(IAsyncStreamReader Schema => _arrowReaderImplementation.ReadSchema(); + public ValueTask Schema => _arrowReaderImplementation.GetSchemaAsync(); internal ValueTask GetFlightDescriptor() { return _arrowReaderImplementation.ReadFlightDescriptor(); - } + } /// /// Get the application metadata from the latest received record batch diff --git a/csharp/src/Apache.Arrow.Flight/Internal/RecordBatchReaderImplementation.cs b/csharp/src/Apache.Arrow.Flight/Internal/RecordBatchReaderImplementation.cs index be844ea58e404..99876bf769dc7 100644 --- a/csharp/src/Apache.Arrow.Flight/Internal/RecordBatchReaderImplementation.cs +++ b/csharp/src/Apache.Arrow.Flight/Internal/RecordBatchReaderImplementation.cs @@ -48,19 +48,33 @@ public async ValueTask ReadFlightDescriptor() { if (!HasReadSchema) { - await ReadSchema().ConfigureAwait(false); + await ReadSchemaAsync(CancellationToken.None).ConfigureAwait(false); } return _flightDescriptor; } - public async ValueTask ReadSchema() + public async ValueTask GetSchemaAsync() + { + if (!HasReadSchema) + { + await ReadSchemaAsync(CancellationToken.None).ConfigureAwait(false); + } + return _schema; + } + + public override void ReadSchema() + { + ReadSchemaAsync(CancellationToken.None).AsTask().Wait(); + } + + public override async ValueTask ReadSchemaAsync(CancellationToken cancellationToken) { if (HasReadSchema) { - return Schema; + return; } - var moveNextResult = await _flightDataStream.MoveNext().ConfigureAwait(false); + var moveNextResult = await _flightDataStream.MoveNext(cancellationToken).ConfigureAwait(false); if (!moveNextResult) { @@ -87,12 +101,11 @@ public async ValueTask ReadSchema() switch (message.HeaderType) { case MessageHeader.Schema: - Schema = FlightMessageSerializer.DecodeSchema(message.ByteBuffer); + _schema = FlightMessageSerializer.DecodeSchema(message.ByteBuffer); break; default: throw new Exception($"Expected schema as the first message, but got: {message.HeaderType.ToString()}"); } - return Schema; } public override async ValueTask ReadNextRecordBatchAsync(CancellationToken cancellationToken) @@ -101,7 +114,7 @@ public override async ValueTask ReadNextRecordBatchAsync(Cancellati if (!HasReadSchema) { - await ReadSchema().ConfigureAwait(false); + await ReadSchemaAsync(cancellationToken).ConfigureAwait(false); } var moveNextResult = await _flightDataStream.MoveNext().ConfigureAwait(false); if (moveNextResult) diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs b/csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs index 02f36b079349b..4b7c5f914c402 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs @@ -52,7 +52,7 @@ public async ValueTask RecordBatchCountAsync(CancellationToken cancellation return _footer.RecordBatchCount; } - protected override async ValueTask ReadSchemaAsync(CancellationToken cancellationToken = default) + public override async ValueTask ReadSchemaAsync(CancellationToken cancellationToken = default) { if (HasReadSchema) { @@ -85,7 +85,7 @@ protected override async ValueTask ReadSchemaAsync(CancellationToken cancellatio } } - protected override void ReadSchema() + public override void ReadSchema() { if (HasReadSchema) { @@ -139,7 +139,7 @@ private void ReadSchema(Memory buffer) // Deserialize the footer from the footer flatbuffer _footer = new ArrowFooter(Flatbuf.Footer.GetRootAsFooter(CreateByteBuffer(buffer)), ref _dictionaryMemo); - Schema = _footer.Schema; + _schema = _footer.Schema; } public async ValueTask ReadRecordBatchAsync(int index, CancellationToken cancellationToken) diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs b/csharp/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs index 6e2336a591bf1..842c56823d07f 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs @@ -33,6 +33,13 @@ public ArrowMemoryReaderImplementation(ReadOnlyMemory buffer, ICompression _buffer = buffer; } + public override ValueTask ReadSchemaAsync(CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + ReadSchema(); + return default; + } + public override ValueTask ReadNextRecordBatchAsync(CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); @@ -93,7 +100,7 @@ public override RecordBatch ReadNextRecordBatch() return batch; } - private void ReadSchema() + public override void ReadSchema() { if (HasReadSchema) { @@ -117,7 +124,7 @@ private void ReadSchema() } ByteBuffer schemaBuffer = CreateByteBuffer(_buffer.Slice(_bufferPosition)); - Schema = MessageSerializer.GetSchema(ReadMessage(schemaBuffer), ref _dictionaryMemo); + _schema = MessageSerializer.GetSchema(ReadMessage(schemaBuffer), ref _dictionaryMemo); _bufferPosition += schemaMessageLength; } } diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs b/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs index eb7349a570786..4e273dbde5690 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs @@ -30,13 +30,25 @@ namespace Apache.Arrow.Ipc { internal abstract class ArrowReaderImplementation : IDisposable { - public Schema Schema { get; protected set; } - protected bool HasReadSchema => Schema != null; + public Schema Schema + { + get + { + if (!HasReadSchema) + { + ReadSchema(); + } + return _schema; + } + } + + protected internal bool HasReadSchema => _schema != null; private protected DictionaryMemo _dictionaryMemo; private protected DictionaryMemo DictionaryMemo => _dictionaryMemo ??= new DictionaryMemo(); private protected readonly MemoryAllocator _allocator; private readonly ICompressionCodecFactory _compressionCodecFactory; + private protected Schema _schema; private protected ArrowReaderImplementation() : this(null, null) { } @@ -57,6 +69,9 @@ protected virtual void Dispose(bool disposing) { } + public abstract ValueTask ReadSchemaAsync(CancellationToken cancellationToken); + public abstract void ReadSchema(); + public abstract ValueTask ReadNextRecordBatchAsync(CancellationToken cancellationToken); public abstract RecordBatch ReadNextRecordBatch(); diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowStreamReader.cs b/csharp/src/Apache.Arrow/Ipc/ArrowStreamReader.cs index cdcfe7875da22..e129da399d59a 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowStreamReader.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowStreamReader.cs @@ -28,6 +28,9 @@ public class ArrowStreamReader : IArrowReader, IArrowArrayStream, IDisposable { private protected readonly ArrowReaderImplementation _implementation; + /// + /// May block if the schema hasn't yet been read. To avoid blocking, use GetSchemaAsync. + /// public Schema Schema => _implementation.Schema; public ArrowStreamReader(Stream stream) @@ -97,6 +100,15 @@ protected virtual void Dispose(bool disposing) } } + public async ValueTask GetSchema(CancellationToken cancellationToken = default) + { + if (!_implementation.HasReadSchema) + { + await _implementation.ReadSchemaAsync(cancellationToken); + } + return _implementation.Schema; + } + public ValueTask ReadNextRecordBatchAsync(CancellationToken cancellationToken = default) { return _implementation.ReadNextRecordBatchAsync(cancellationToken); diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs b/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs index 5428c88c27bbc..5583a58487bf5 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs @@ -146,7 +146,7 @@ protected ReadResult ReadMessage() return new ReadResult(messageLength, result); } - protected virtual async ValueTask ReadSchemaAsync(CancellationToken cancellationToken = default) + public override async ValueTask ReadSchemaAsync(CancellationToken cancellationToken = default) { if (HasReadSchema) { @@ -164,11 +164,11 @@ protected virtual async ValueTask ReadSchemaAsync(CancellationToken cancellation EnsureFullRead(buff, bytesRead); Google.FlatBuffers.ByteBuffer schemabb = CreateByteBuffer(buff); - Schema = MessageSerializer.GetSchema(ReadMessage(schemabb), ref _dictionaryMemo); + _schema = MessageSerializer.GetSchema(ReadMessage(schemabb), ref _dictionaryMemo); } } - protected virtual void ReadSchema() + public override void ReadSchema() { if (HasReadSchema) { @@ -184,7 +184,7 @@ protected virtual void ReadSchema() EnsureFullRead(buff, bytesRead); Google.FlatBuffers.ByteBuffer schemabb = CreateByteBuffer(buff); - Schema = MessageSerializer.GetSchema(ReadMessage(schemabb), ref _dictionaryMemo); + _schema = MessageSerializer.GetSchema(ReadMessage(schemabb), ref _dictionaryMemo); } } diff --git a/csharp/test/Apache.Arrow.Tests/ArrowReaderVerifier.cs b/csharp/test/Apache.Arrow.Tests/ArrowReaderVerifier.cs index 10315ff287c0b..2e7488092c2cf 100644 --- a/csharp/test/Apache.Arrow.Tests/ArrowReaderVerifier.cs +++ b/csharp/test/Apache.Arrow.Tests/ArrowReaderVerifier.cs @@ -38,6 +38,9 @@ public static void VerifyReader(ArrowStreamReader reader, RecordBatch originalBa public static async Task VerifyReaderAsync(ArrowStreamReader reader, RecordBatch originalBatch) { + Schema schema = await reader.GetSchema(); + Assert.NotNull(schema); + RecordBatch readBatch = await reader.ReadNextRecordBatchAsync(); CompareBatches(originalBatch, readBatch); diff --git a/csharp/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs b/csharp/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs index ed030cc6ace11..b9e4664fdcd45 100644 --- a/csharp/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs +++ b/csharp/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs @@ -94,6 +94,8 @@ public async Task ReadRecordBatch_Memory(bool writeEnd) { await TestReaderFromMemory((reader, originalBatch) => { + Assert.NotNull(reader.Schema); + ArrowReaderVerifier.VerifyReader(reader, originalBatch); return Task.CompletedTask; }, writeEnd);