Skip to content

Commit

Permalink
apacheGH-40634: [C#] ArrowStreamReader should not be null (apache#40765)
Browse files Browse the repository at this point in the history
### What changes are included in this PR?

Small refactoring in the IPC reader implementation classes of how the schema is read in order to support getting the schema asynchronously through ArrowStreamReader and avoiding the case where ArrowStreamReader.Schema returns null because no record batches have yet been read.

### Are these changes tested?

Yes.

### Are there any user-facing changes?

A new method ArrowStreamReader.GetSchema has been added to allow the schema to be gotten asynchronously.

Closes apache#40634 
* GitHub Issue: apache#40634

Authored-by: Curt Hagenlocher <[email protected]>
Signed-off-by: Curt Hagenlocher <[email protected]>
  • Loading branch information
CurtHagenlocher authored Mar 25, 2024
1 parent 8133a20 commit cc771a0
Show file tree
Hide file tree
Showing 9 changed files with 72 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,12 @@ private protected FlightRecordBatchStreamReader(IAsyncStreamReader<Protocol.Flig
_arrowReaderImplementation = new RecordBatchReaderImplementation(flightDataStream);
}

public ValueTask<Schema> Schema => _arrowReaderImplementation.ReadSchema();
public ValueTask<Schema> Schema => _arrowReaderImplementation.GetSchemaAsync();

internal ValueTask<FlightDescriptor> GetFlightDescriptor()
{
return _arrowReaderImplementation.ReadFlightDescriptor();
}
}

/// <summary>
/// Get the application metadata from the latest received record batch
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,19 +48,33 @@ public async ValueTask<FlightDescriptor> ReadFlightDescriptor()
{
if (!HasReadSchema)
{
await ReadSchema().ConfigureAwait(false);
await ReadSchemaAsync(CancellationToken.None).ConfigureAwait(false);
}
return _flightDescriptor;
}

public async ValueTask<Schema> ReadSchema()
public async ValueTask<Schema> GetSchemaAsync()
{
if (!HasReadSchema)
{
await ReadSchemaAsync(CancellationToken.None).ConfigureAwait(false);
}
return _schema;
}

public override void ReadSchema()
{
ReadSchemaAsync(CancellationToken.None).AsTask().Wait();
}

public override async ValueTask ReadSchemaAsync(CancellationToken cancellationToken)
{
if (HasReadSchema)
{
return Schema;
return;
}

var moveNextResult = await _flightDataStream.MoveNext().ConfigureAwait(false);
var moveNextResult = await _flightDataStream.MoveNext(cancellationToken).ConfigureAwait(false);

if (!moveNextResult)
{
Expand All @@ -87,12 +101,11 @@ public async ValueTask<Schema> ReadSchema()
switch (message.HeaderType)
{
case MessageHeader.Schema:
Schema = FlightMessageSerializer.DecodeSchema(message.ByteBuffer);
_schema = FlightMessageSerializer.DecodeSchema(message.ByteBuffer);
break;
default:
throw new Exception($"Expected schema as the first message, but got: {message.HeaderType.ToString()}");
}
return Schema;
}

public override async ValueTask<RecordBatch> ReadNextRecordBatchAsync(CancellationToken cancellationToken)
Expand All @@ -101,7 +114,7 @@ public override async ValueTask<RecordBatch> ReadNextRecordBatchAsync(Cancellati

if (!HasReadSchema)
{
await ReadSchema().ConfigureAwait(false);
await ReadSchemaAsync(cancellationToken).ConfigureAwait(false);
}
var moveNextResult = await _flightDataStream.MoveNext().ConfigureAwait(false);
if (moveNextResult)
Expand Down
6 changes: 3 additions & 3 deletions csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public async ValueTask<int> RecordBatchCountAsync(CancellationToken cancellation
return _footer.RecordBatchCount;
}

protected override async ValueTask ReadSchemaAsync(CancellationToken cancellationToken = default)
public override async ValueTask ReadSchemaAsync(CancellationToken cancellationToken = default)
{
if (HasReadSchema)
{
Expand Down Expand Up @@ -85,7 +85,7 @@ protected override async ValueTask ReadSchemaAsync(CancellationToken cancellatio
}
}

protected override void ReadSchema()
public override void ReadSchema()
{
if (HasReadSchema)
{
Expand Down Expand Up @@ -139,7 +139,7 @@ private void ReadSchema(Memory<byte> buffer)
// Deserialize the footer from the footer flatbuffer
_footer = new ArrowFooter(Flatbuf.Footer.GetRootAsFooter(CreateByteBuffer(buffer)), ref _dictionaryMemo);

Schema = _footer.Schema;
_schema = _footer.Schema;
}

public async ValueTask<RecordBatch> ReadRecordBatchAsync(int index, CancellationToken cancellationToken)
Expand Down
11 changes: 9 additions & 2 deletions csharp/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,13 @@ public ArrowMemoryReaderImplementation(ReadOnlyMemory<byte> buffer, ICompression
_buffer = buffer;
}

public override ValueTask ReadSchemaAsync(CancellationToken cancellationToken)
{
cancellationToken.ThrowIfCancellationRequested();
ReadSchema();
return default;
}

public override ValueTask<RecordBatch> ReadNextRecordBatchAsync(CancellationToken cancellationToken)
{
cancellationToken.ThrowIfCancellationRequested();
Expand Down Expand Up @@ -93,7 +100,7 @@ public override RecordBatch ReadNextRecordBatch()
return batch;
}

private void ReadSchema()
public override void ReadSchema()
{
if (HasReadSchema)
{
Expand All @@ -117,7 +124,7 @@ private void ReadSchema()
}

ByteBuffer schemaBuffer = CreateByteBuffer(_buffer.Slice(_bufferPosition));
Schema = MessageSerializer.GetSchema(ReadMessage<Flatbuf.Schema>(schemaBuffer), ref _dictionaryMemo);
_schema = MessageSerializer.GetSchema(ReadMessage<Flatbuf.Schema>(schemaBuffer), ref _dictionaryMemo);
_bufferPosition += schemaMessageLength;
}
}
Expand Down
19 changes: 17 additions & 2 deletions csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,25 @@ namespace Apache.Arrow.Ipc
{
internal abstract class ArrowReaderImplementation : IDisposable
{
public Schema Schema { get; protected set; }
protected bool HasReadSchema => Schema != null;
public Schema Schema
{
get
{
if (!HasReadSchema)
{
ReadSchema();
}
return _schema;
}
}

protected internal bool HasReadSchema => _schema != null;

private protected DictionaryMemo _dictionaryMemo;
private protected DictionaryMemo DictionaryMemo => _dictionaryMemo ??= new DictionaryMemo();
private protected readonly MemoryAllocator _allocator;
private readonly ICompressionCodecFactory _compressionCodecFactory;
private protected Schema _schema;

private protected ArrowReaderImplementation() : this(null, null)
{ }
Expand All @@ -57,6 +69,9 @@ protected virtual void Dispose(bool disposing)
{
}

public abstract ValueTask ReadSchemaAsync(CancellationToken cancellationToken);
public abstract void ReadSchema();

public abstract ValueTask<RecordBatch> ReadNextRecordBatchAsync(CancellationToken cancellationToken);
public abstract RecordBatch ReadNextRecordBatch();

Expand Down
12 changes: 12 additions & 0 deletions csharp/src/Apache.Arrow/Ipc/ArrowStreamReader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ public class ArrowStreamReader : IArrowReader, IArrowArrayStream, IDisposable
{
private protected readonly ArrowReaderImplementation _implementation;

/// <summary>
/// May block if the schema hasn't yet been read. To avoid blocking, use GetSchemaAsync.
/// </summary>
public Schema Schema => _implementation.Schema;

public ArrowStreamReader(Stream stream)
Expand Down Expand Up @@ -97,6 +100,15 @@ protected virtual void Dispose(bool disposing)
}
}

public async ValueTask<Schema> GetSchema(CancellationToken cancellationToken = default)
{
if (!_implementation.HasReadSchema)
{
await _implementation.ReadSchemaAsync(cancellationToken);
}
return _implementation.Schema;
}

public ValueTask<RecordBatch> ReadNextRecordBatchAsync(CancellationToken cancellationToken = default)
{
return _implementation.ReadNextRecordBatchAsync(cancellationToken);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ protected ReadResult ReadMessage()
return new ReadResult(messageLength, result);
}

protected virtual async ValueTask ReadSchemaAsync(CancellationToken cancellationToken = default)
public override async ValueTask ReadSchemaAsync(CancellationToken cancellationToken = default)
{
if (HasReadSchema)
{
Expand All @@ -164,11 +164,11 @@ protected virtual async ValueTask ReadSchemaAsync(CancellationToken cancellation
EnsureFullRead(buff, bytesRead);

Google.FlatBuffers.ByteBuffer schemabb = CreateByteBuffer(buff);
Schema = MessageSerializer.GetSchema(ReadMessage<Flatbuf.Schema>(schemabb), ref _dictionaryMemo);
_schema = MessageSerializer.GetSchema(ReadMessage<Flatbuf.Schema>(schemabb), ref _dictionaryMemo);
}
}

protected virtual void ReadSchema()
public override void ReadSchema()
{
if (HasReadSchema)
{
Expand All @@ -184,7 +184,7 @@ protected virtual void ReadSchema()
EnsureFullRead(buff, bytesRead);

Google.FlatBuffers.ByteBuffer schemabb = CreateByteBuffer(buff);
Schema = MessageSerializer.GetSchema(ReadMessage<Flatbuf.Schema>(schemabb), ref _dictionaryMemo);
_schema = MessageSerializer.GetSchema(ReadMessage<Flatbuf.Schema>(schemabb), ref _dictionaryMemo);
}
}

Expand Down
3 changes: 3 additions & 0 deletions csharp/test/Apache.Arrow.Tests/ArrowReaderVerifier.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ public static void VerifyReader(ArrowStreamReader reader, RecordBatch originalBa

public static async Task VerifyReaderAsync(ArrowStreamReader reader, RecordBatch originalBatch)
{
Schema schema = await reader.GetSchema();
Assert.NotNull(schema);

RecordBatch readBatch = await reader.ReadNextRecordBatchAsync();
CompareBatches(originalBatch, readBatch);

Expand Down
2 changes: 2 additions & 0 deletions csharp/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ public async Task ReadRecordBatch_Memory(bool writeEnd)
{
await TestReaderFromMemory((reader, originalBatch) =>
{
Assert.NotNull(reader.Schema);

ArrowReaderVerifier.VerifyReader(reader, originalBatch);
return Task.CompletedTask;
}, writeEnd);
Expand Down

0 comments on commit cc771a0

Please sign in to comment.