Skip to content

Commit

Permalink
LongTotalLength
Browse files Browse the repository at this point in the history
  • Loading branch information
CurtHagenlocher committed Dec 11, 2023
1 parent c87faf8 commit 6e95567
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 16 deletions.
6 changes: 5 additions & 1 deletion csharp/src/Apache.Arrow.Flight/Internal/FlightDataStream.cs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,11 @@ public async Task Write(RecordBatch recordBatch, ByteString applicationMetadata)
await _clientStreamWriter.WriteAsync(_currentFlightData).ConfigureAwait(false);
}

private protected override ValueTask<long> WriteMessageAsync<T>(MessageHeader headerType, Offset<T> headerOffset, int bodyLength, CancellationToken cancellationToken)
private protected override ValueTask<long> WriteMessageAsync<T>(
MessageHeader headerType,
Offset<T> headerOffset,
long bodyLength,
CancellationToken cancellationToken)
{
Offset<Flatbuf.Message> messageOffset = Flatbuf.Message.CreateMessage(
Builder, CurrentMetadataVersion, headerType, headerOffset.Value,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ public override async ValueTask<RecordBatch> ReadNextRecordBatchAsync(Cancellati
{
case MessageHeader.RecordBatch:
var body = _flightDataStream.Current.DataBody.Memory;
return CreateArrowObjectFromMessage(message, CreateByteBuffer(body.Slice(0, (int)message.BodyLength)), null);
return CreateArrowObjectFromMessage(message, CreateByteBuffer(body.Slice(0, checked((int)message.BodyLength))), null);
default:
throw new NotImplementedException();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ public override RecordBatch ReadNextRecordBatch()
CreateByteBuffer(_buffer.Slice(_bufferPosition, messageLength)));
_bufferPosition += messageLength;

int bodyLength = (int)message.BodyLength;
int bodyLength = checked((int)message.BodyLength);
ByteBuffer bodybb = CreateByteBuffer(_buffer.Slice(_bufferPosition, bodyLength));
_bufferPosition += bodyLength;

Expand Down
14 changes: 7 additions & 7 deletions csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ protected RecordBatch CreateArrowObjectFromMessage(
case Flatbuf.MessageHeader.RecordBatch:
Flatbuf.RecordBatch rb = message.Header<Flatbuf.RecordBatch>().Value;
List<IArrowArray> arrays = BuildArrays(message.Version, Schema, bodyByteBuffer, rb);
return new RecordBatch(Schema, memoryOwner, arrays, (int)rb.Length);
return new RecordBatch(Schema, memoryOwner, arrays, checked((int)rb.Length));
default:
// NOTE: Skip unsupported message type
Debug.WriteLine($"Skipping unsupported message type '{message.HeaderType}'");
Expand Down Expand Up @@ -238,8 +238,8 @@ private ArrayData LoadPrimitiveField(
IBufferCreator bufferCreator)
{

int fieldLength = (int)fieldNode.Length;
int fieldNullCount = (int)fieldNode.NullCount;
int fieldLength = checked((int)fieldNode.Length);
int fieldNullCount = checked((int)fieldNode.NullCount);

if (fieldLength < 0)
{
Expand Down Expand Up @@ -322,8 +322,8 @@ private ArrayData LoadVariableField(
ArrowBuffer valueArrowBuffer = BuildArrowBuffer(bodyData, recordBatchEnumerator.CurrentBuffer, bufferCreator);
recordBatchEnumerator.MoveNextBuffer();

int fieldLength = (int)fieldNode.Length;
int fieldNullCount = (int)fieldNode.NullCount;
int fieldLength = checked((int)fieldNode.Length);
int fieldNullCount = checked((int)fieldNode.NullCount);

if (fieldLength < 0)
{
Expand Down Expand Up @@ -381,8 +381,8 @@ private ArrowBuffer BuildArrowBuffer(ByteBuffer bodyData, Flatbuf.Buffer buffer,
return ArrowBuffer.Empty;
}

int offset = (int)buffer.Offset;
int length = (int)buffer.Length;
int offset = checked((int)buffer.Offset);
int length = checked((int)buffer.Length);

var data = bodyData.ToReadOnlyMemory(offset, length);
return bufferCreator.CreateBuffer(data);
Expand Down
12 changes: 6 additions & 6 deletions csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ internal class ArrowRecordBatchFlatBufferBuilder :
public readonly struct Buffer
{
public readonly ArrowBuffer DataBuffer;
public readonly int Offset;
public readonly long Offset;

public Buffer(ArrowBuffer buffer, int offset)
public Buffer(ArrowBuffer buffer, long offset)
{
DataBuffer = buffer;
Offset = offset;
Expand All @@ -81,7 +81,7 @@ public Buffer(ArrowBuffer buffer, int offset)

public IReadOnlyList<Buffer> Buffers => _buffers;

public int TotalLength { get; private set; }
public long TotalLength { get; private set; }

public ArrowRecordBatchFlatBufferBuilder()
{
Expand Down Expand Up @@ -210,7 +210,7 @@ private void CreateBuffers<T>(PrimitiveArray<T> array)

private Buffer CreateBuffer(ArrowBuffer buffer)
{
int offset = TotalLength;
long offset = TotalLength;

int paddedLength = checked((int)BitUtility.RoundUpToMultipleOf8(buffer.Length));
TotalLength += paddedLength;
Expand Down Expand Up @@ -819,7 +819,7 @@ await WriteMessageAsync(Flatbuf.MessageHeader.Schema, schemaOffset, 0, cancellat
/// The number of bytes written to the stream.
/// </returns>
private protected long WriteMessage<T>(
Flatbuf.MessageHeader headerType, Offset<T> headerOffset, int bodyLength)
Flatbuf.MessageHeader headerType, Offset<T> headerOffset, long bodyLength)
where T : struct
{
Offset<Flatbuf.Message> messageOffset = Flatbuf.Message.CreateMessage(
Expand Down Expand Up @@ -849,7 +849,7 @@ private protected long WriteMessage<T>(
/// The number of bytes written to the stream.
/// </returns>
private protected virtual async ValueTask<long> WriteMessageAsync<T>(
Flatbuf.MessageHeader headerType, Offset<T> headerOffset, int bodyLength,
Flatbuf.MessageHeader headerType, Offset<T> headerOffset, long bodyLength,
CancellationToken cancellationToken)
where T : struct
{
Expand Down
53 changes: 53 additions & 0 deletions csharp/test/Apache.Arrow.Tests/ArrowFileWriterTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
// limitations under the License.

using Apache.Arrow.Ipc;
using Apache.Arrow.Types;
using System;
using System.IO;
using System.Linq;
using System.Threading.Tasks;
using Xunit;

Expand Down Expand Up @@ -164,5 +166,56 @@ public async Task WritesEmptyFileAsync()
Assert.Null(readBatch);
SchemaComparer.Compare(originalBatch.Schema, reader.Schema);
}

/// <summary>
/// Writes a multi-gigabyte Arrow file and reads it back, verifying that the
/// batch count, the row count and the column count survive the round trip.
/// </summary>
/// <remarks>
/// Each batch carries 25 Int64 columns of 1,000,000 rows (200,000,000 body bytes),
/// and 25 batches are written, so the file holds roughly 5 GB of body data --
/// more than <see cref="int.MaxValue"/> bytes.
/// </remarks>
[Fact]
public async Task WritesLargeFileAsync()
{
    int rowCount = 1_000_000;
    int batchCount = 25;
    int columns = 25;

    // A single zero-filled values buffer backs every column, so the test only
    // allocates rowCount * sizeof(long) bytes of managed memory for the data.
    Int64Array manyZeroes = new Int64Array(new ArrowBuffer(new byte[sizeof(long) * rowCount]), ArrowBuffer.Empty, rowCount, 0, 0);
    Schema schema = new Schema(Enumerable.Repeat(new Field("", Int64Type.Default, false), columns).ToArray(), null);
    RecordBatch batch = new RecordBatch(schema, Enumerable.Repeat(manyZeroes, columns), rowCount);

    string tempFile = Path.GetTempFileName();
    try
    {
        using (Stream stream = File.OpenWrite(tempFile))
        using (ArrowFileWriter writer = new ArrowFileWriter(stream, schema))
        {
            await writer.WriteStartAsync();
            for (int i = 0; i < batchCount; i++)
            {
                await writer.WriteRecordBatchAsync(batch);
            }
            await writer.WriteEndAsync();
        }

        using (Stream stream = File.OpenRead(tempFile))
        using (ArrowFileReader reader = new ArrowFileReader(stream))
        {
            int readCount = await reader.RecordBatchCountAsync();
            Assert.Equal(batchCount, readCount);

            for (int i = 0; i < batchCount; i++)
            {
                // Dispose each batch as soon as it has been checked so that the
                // memory of one batch is released before the next one is read.
                using (RecordBatch readBatch = await reader.ReadNextRecordBatchAsync())
                {
                    // Fail with a clear assertion instead of a NullReferenceException
                    // if the reader yields no batch here.
                    Assert.NotNull(readBatch);
                    Assert.Equal(rowCount, readBatch.Length);
                    Assert.Equal(columns, readBatch.Schema.FieldsList.Count);
                }
            }
        }
    }
    finally
    {
        try
        {
            File.Delete(tempFile);
        }
        catch (Exception)
        {
            // Best-effort cleanup: a failure to delete the temporary file must
            // not mask the outcome of the test itself.
        }
    }
}
}
}

0 comments on commit 6e95567

Please sign in to comment.