Skip to content

Commit

Permalink
Add SqlDecimal support to Decimal128Array
Browse files Browse the repository at this point in the history
  • Loading branch information
CurtHagenlocher committed Oct 26, 2023
1 parent 57f643c commit e77131d
Show file tree
Hide file tree
Showing 3 changed files with 194 additions and 8 deletions.
40 changes: 40 additions & 0 deletions csharp/src/Apache.Arrow/Arrays/Decimal128Array.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@

using System;
using System.Collections.Generic;
#if !NETSTANDARD1_3
using System.Data.SqlTypes;
#endif
using System.Diagnostics;
using System.Numerics;
using Apache.Arrow.Arrays;
Expand Down Expand Up @@ -61,6 +64,31 @@ public Builder AppendRange(IEnumerable<decimal> values)
return Instance;
}

#if !NETSTANDARD1_3
public Builder Append(SqlDecimal value)
{
Span<byte> bytes = stackalloc byte[DataType.ByteWidth];
DecimalUtility.GetBytes(value, DataType.Precision, DataType.Scale, bytes);

return Append(bytes);
}

public Builder AppendRange(IEnumerable<SqlDecimal> values)
{
if (values == null)
{
throw new ArgumentNullException(nameof(values));
}

foreach (decimal d in values)
{
Append(d);
}

return Instance;
}
#endif

public Builder Set(int index, decimal value)
{
Span<byte> bytes = stackalloc byte[DataType.ByteWidth];
Expand Down Expand Up @@ -91,5 +119,17 @@ public Decimal128Array(ArrayData data)
}
return DecimalUtility.GetDecimal(ValueBuffer, index, Scale, ByteWidth);
}

#if !NETSTANDARD1_3
public SqlDecimal? GetSqlDecimal(int index)
{
if (IsNull(index))
{
return null;
}

return DecimalUtility.GetSqlDecimal128(ValueBuffer, index, Precision, Scale);
}
#endif
}
}
45 changes: 45 additions & 0 deletions csharp/src/Apache.Arrow/DecimalUtility.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
// limitations under the License.

using System;
#if !NETSTANDARD1_3
using System.Data.SqlTypes;
#endif
using System.Numerics;

namespace Apache.Arrow
Expand Down Expand Up @@ -73,6 +76,27 @@ internal static decimal GetDecimal(in ArrowBuffer valueBuffer, int index, int sc
}
}

#if !NETSTANDARD1_3
internal static SqlDecimal GetSqlDecimal128(in ArrowBuffer valueBuffer, int index, int precision, int scale)
{
const int byteWidth = 16;
const int intWidth = byteWidth / 4;

byte mostSignificantByte = valueBuffer.Span[(index + 1) * byteWidth - 1];
bool isPositive = (mostSignificantByte & 0x80) == 0;

ReadOnlySpan<int> value = valueBuffer.Span.CastTo<int>().Slice(index * intWidth, intWidth);
if (isPositive)
{
return new SqlDecimal((byte)precision, (byte)scale, true, value[0], value[1], value[2], value[3]);
}
else
{
return new SqlDecimal((byte)precision, (byte)scale, false, -value[0], ~value[1], ~value[2], ~value[3]);
}
}
#endif

private static decimal DivideByScale(BigInteger integerValue, int scale)
{
decimal result = (decimal)integerValue; // this cast is safe here
Expand Down Expand Up @@ -169,5 +193,26 @@ internal static void GetBytes(decimal value, int precision, int scale, int byteW
}
}
}

#if !NETSTANDARD1_3
internal static void GetBytes(SqlDecimal value, int precision, int scale, Span<byte> bytes)
{
if (value.Precision != precision || value.Scale != scale)
{
value = SqlDecimal.ConvertToPrecScale(value, precision, scale);
}

// TODO: Consider groveling in the internals to avoid the probable allocation
Span<int> span = bytes.CastTo<int>();
value.Data.AsSpan().CopyTo(span);
if (!value.IsPositive)
{
span[0] = -span[0];
span[1] = ~span[1];
span[2] = ~span[2];
span[3] = ~span[3];
}
}
#endif
}
}
117 changes: 109 additions & 8 deletions csharp/test/Apache.Arrow.Tests/Decimal128ArrayTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,28 @@
// limitations under the License.

using System;
using System.Collections.Generic;
#if !NETSTANDARD1_3
using System.Data.SqlTypes;
#endif
using Apache.Arrow.Types;
using Xunit;

namespace Apache.Arrow.Tests
{
public class Decimal128ArrayTests
{
#if !NETSTANDARD1_3
static SqlDecimal? Convert(decimal? value)
{
return value == null ? null : new SqlDecimal(value.Value);
}

static decimal? Convert(SqlDecimal? value)
{
return value == null ? null : value.Value.Value;
}
#endif

public class Builder
{
public class AppendNull
Expand All @@ -30,7 +44,7 @@ public class AppendNull
public void AppendThenGetGivesNull()
{
// Arrange
var builder = new Decimal128Array.Builder(new Decimal128Type(8,2));
var builder = new Decimal128Array.Builder(new Decimal128Type(8, 2));

// Act

Expand All @@ -45,6 +59,12 @@ public void AppendThenGetGivesNull()
Assert.Null(array.GetValue(0));
Assert.Null(array.GetValue(1));
Assert.Null(array.GetValue(2));

#if !NETSTANDARD1_3
Assert.Null(array.GetSqlDecimal(0));
Assert.Null(array.GetSqlDecimal(1));
Assert.Null(array.GetSqlDecimal(2));
#endif
}
}

Expand All @@ -67,7 +87,7 @@ public void AppendDecimal(int count)
testData[i] = null;
continue;
}
decimal rnd = i * (decimal)Math.Round(new Random().NextDouble(),10);
decimal rnd = i * (decimal)Math.Round(new Random().NextDouble(), 10);
testData[i] = rnd;
builder.Append(rnd);
}
Expand All @@ -78,6 +98,9 @@ public void AppendDecimal(int count)
for (int i = 0; i < count; i++)
{
Assert.Equal(testData[i], array.GetValue(i));
#if !NETSTANDARD1_3
Assert.Equal(Convert(testData[i]), array.GetSqlDecimal(i));
#endif
}
}

Expand All @@ -95,6 +118,11 @@ public void AppendLargeDecimal()
var array = builder.Build();
Assert.Equal(large, array.GetValue(0));
Assert.Equal(-large, array.GetValue(1));

#if !NETSTANDARD1_3
Assert.Equal(Convert(large), array.GetSqlDecimal(0));
Assert.Equal(Convert(-large), array.GetSqlDecimal(1));
#endif
}

[Fact]
Expand All @@ -115,6 +143,13 @@ public void AppendMaxAndMinDecimal()
Assert.Equal(Decimal.MinValue, array.GetValue(1));
Assert.Equal(Decimal.MaxValue - 10, array.GetValue(2));
Assert.Equal(Decimal.MinValue + 10, array.GetValue(3));

#if !NETSTANDARD1_3
Assert.Equal(Convert(Decimal.MaxValue), array.GetSqlDecimal(0));
Assert.Equal(Convert(Decimal.MinValue), array.GetSqlDecimal(1));
Assert.Equal(Convert(Decimal.MaxValue) - 10, array.GetSqlDecimal(2));
Assert.Equal(Convert(Decimal.MinValue) + 10, array.GetSqlDecimal(3));
#endif
}

[Fact]
Expand All @@ -131,35 +166,43 @@ public void AppendFractionalDecimal()
var array = builder.Build();
Assert.Equal(fraction, array.GetValue(0));
Assert.Equal(-fraction, array.GetValue(1));

#if !NETSTANDARD1_3
Assert.Equal(Convert(fraction), array.GetSqlDecimal(0));
Assert.Equal(Convert(-fraction), array.GetSqlDecimal(1));
#endif
}

[Fact]
public void AppendRangeDecimal()
{
// Arrange
var builder = new Decimal128Array.Builder(new Decimal128Type(24, 8));
var range = new decimal[] {2.123M, 1.5984M, -0.0000001M, 9878987987987987.1235407M};
var range = new decimal[] { 2.123M, 1.5984M, -0.0000001M, 9878987987987987.1235407M };

// Act
builder.AppendRange(range);
builder.AppendNull();

// Assert
var array = builder.Build();
for(int i = 0; i < range.Length; i ++)
for (int i = 0; i < range.Length; i++)
{
Assert.Equal(range[i], array.GetValue(i));
#if !NETSTANDARD1_3
Assert.Equal(Convert(range[i]), array.GetSqlDecimal(i));
#endif
}
Assert.Null( array.GetValue(range.Length));

Assert.Null(array.GetValue(range.Length));
}

[Fact]
public void AppendClearAppendDecimal()
{
// Arrange
var builder = new Decimal128Array.Builder(new Decimal128Type(24, 8));

// Act
builder.Append(1);
builder.Clear();
Expand Down Expand Up @@ -256,6 +299,64 @@ public void SwapNull()
Assert.Equal(123.456M, array.GetValue(1));
}
}

#if !NETSTANDARD1_3
public class SqlDecimals
{
[Theory]
[InlineData(200)]
public void AppendSqlDecimal(int count)
{
// Arrange
const int precision = 10;
var builder = new Decimal128Array.Builder(new Decimal128Type(14, precision));

// Act
SqlDecimal?[] testData = new SqlDecimal?[count];
for (int i = 0; i < count; i++)
{
if (i == count - 2)
{
builder.AppendNull();
testData[i] = null;
continue;
}
SqlDecimal rnd = i * (SqlDecimal)Math.Round(new Random().NextDouble(), 10);
builder.Append(rnd);
testData[i] = SqlDecimal.Round(rnd, precision);
}

// Assert
var array = builder.Build();
Assert.Equal(count, array.Length);
for (int i = 0; i < count; i++)
{
Assert.Equal(testData[i], array.GetSqlDecimal(i));
Assert.Equal(Convert(testData[i]), array.GetValue(i));
}
}

[Fact]
public void AppendMaxAndMinSqlDecimal()
{
// Arrange
var builder = new Decimal128Array.Builder(new Decimal128Type(38, 0));

// Act
builder.Append(SqlDecimal.MaxValue);
builder.Append(SqlDecimal.MinValue);
builder.Append(SqlDecimal.MaxValue - 10);
builder.Append(SqlDecimal.MinValue + 10);

// Assert
var array = builder.Build();
Assert.Equal(SqlDecimal.MaxValue, array.GetSqlDecimal(0));
Assert.Equal(SqlDecimal.MinValue, array.GetSqlDecimal(1));
Assert.Equal(SqlDecimal.MaxValue - 10, array.GetSqlDecimal(2));
Assert.Equal(SqlDecimal.MinValue + 10, array.GetSqlDecimal(3));
}
}
#endif
}
}
}

0 comments on commit e77131d

Please sign in to comment.