Skip to content

Commit

Permalink
SNOW-979288: Add explicit DbType Parameter assignment (#889)
Browse files Browse the repository at this point in the history
### Description
Add explicit DbType Parameter assignment

### Checklist
- [ ] Code compiles correctly
- [ ] Code is formatted according to [Coding
Conventions](../CodingConventions.md)
- [ ] Created tests which fail without the change (if possible)
- [ ] All tests passing (`dotnet test`)
- [ ] Extended the README / documentation, if necessary
- [ ] Provide JIRA issue id (if possible) or GitHub issue id in PR name
  • Loading branch information
sfc-gh-ext-simba-lf authored Apr 1, 2024
1 parent 0303514 commit ac0860f
Show file tree
Hide file tree
Showing 4 changed files with 203 additions and 1 deletion.
64 changes: 64 additions & 0 deletions Snowflake.Data.Tests/IntegrationTests/SFBindTestIT.cs
Original file line number Diff line number Diff line change
Expand Up @@ -769,5 +769,69 @@ public void testPutArrayBind1()
conn.Close();
}
}

[Test]
public void testExplicitDbTypeAssignmentForSimpleValue()
{

using (IDbConnection conn = new SnowflakeDbConnection())
{
conn.ConnectionString = ConnectionString;
conn.Open();

CreateOrReplaceTable(conn, TableName, new[]
{
"cola INTEGER",
});

using (IDbCommand cmd = conn.CreateCommand())
{
string insertCommand = $"insert into {TableName} values (?)";
cmd.CommandText = insertCommand;

var p1 = cmd.CreateParameter();
p1.ParameterName = "1";
p1.Value = 1;
cmd.Parameters.Add(p1);

var count = cmd.ExecuteNonQuery();
Assert.AreEqual(1, count);
}

conn.Close();
}
}

[Test]
public void testExplicitDbTypeAssignmentForArrayValue()
{

using (IDbConnection conn = new SnowflakeDbConnection())
{
conn.ConnectionString = ConnectionString;
conn.Open();

CreateOrReplaceTable(conn, TableName, new[]
{
"cola INTEGER",
});

using (IDbCommand cmd = conn.CreateCommand())
{
string insertCommand = $"insert into {TableName} values (?)";
cmd.CommandText = insertCommand;

var p1 = cmd.CreateParameter();
p1.ParameterName = "1";
p1.Value = new int[] { 1, 2, 3 };
cmd.Parameters.Add(p1);

var count = cmd.ExecuteNonQuery();
Assert.AreEqual(3, count);
}

conn.Close();
}
}
}
}
90 changes: 90 additions & 0 deletions Snowflake.Data.Tests/UnitTests/SFDbParameterTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ namespace Snowflake.Data.Tests
using NUnit.Framework;
using Snowflake.Data.Client;
using Snowflake.Data.Core;
using System;
using System.Data;
using System.Text;

[TestFixture]
class SFDbParameterTest
Expand Down Expand Up @@ -125,5 +127,93 @@ public void TestDbParameterResetDbType([Values] SFDataType expectedSFDataType)
_parameter.ResetDbType();
Assert.AreEqual(SFDataType.None, _parameter.SFDataType);
}

[Test]
public void TestDbTypeExplicitAssignment([Values] DbType expectedDbType)
{
_parameter = new SnowflakeDbParameter();

switch (expectedDbType)
{
case DbType.SByte:
_parameter.Value = new sbyte();
break;
case DbType.Byte:
_parameter.Value = new byte();
break;
case DbType.Int16:
_parameter.Value = new short();
break;
case DbType.Int32:
_parameter.Value = new int();
break;
case DbType.Int64:
_parameter.Value = new long();
break;
case DbType.UInt16:
_parameter.Value = new ushort();
break;
case DbType.UInt32:
_parameter.Value = new uint();
break;
case DbType.UInt64:
_parameter.Value = new ulong();
break;
case DbType.Decimal:
_parameter.Value = new decimal();
break;
case DbType.Boolean:
_parameter.Value = true;
break;
case DbType.Single:
_parameter.Value = new float();
break;
case DbType.Double:
_parameter.Value = new double();
break;
case DbType.Guid:
_parameter.Value = new Guid();
break;
case DbType.String:
_parameter.Value = "thisIsAString";
break;
case DbType.DateTime:
_parameter.Value = DateTime.Now;
break;
case DbType.DateTimeOffset:
_parameter.Value = DateTimeOffset.Now;
break;
case DbType.Binary:
_parameter.Value = Encoding.UTF8.GetBytes("BinaryData");
break;
case DbType.Object:
_parameter.Value = new object();
break;
default:
// Not supported
expectedDbType = default(DbType);
break;
}

Assert.AreEqual(expectedDbType, _parameter.DbType);
}

[Test]
public void TestDbTypeExplicitAssignmentWithNullValueAndDefaultDbType()
{
_parameter = new SnowflakeDbParameter();
_parameter.Value = null;
Assert.AreEqual(default(DbType), _parameter.DbType);
}

[Test]
public void TestDbTypeExplicitAssignmentWithNullValueAndNonDefaultDbType()
{
var nonDefaultDbType = DbType.String;
_parameter = new SnowflakeDbParameter();
_parameter.Value = null;
_parameter.DbType = nonDefaultDbType;
Assert.AreEqual(nonDefaultDbType, _parameter.DbType);
}
}
}
25 changes: 24 additions & 1 deletion Snowflake.Data/Client/SnowflakeDbParameter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ public class SnowflakeDbParameter : DbParameter

private SFDataType OriginType;

private DbType _dbType;

public SnowflakeDbParameter()
{
SFDataType = SFDataType.None;
Expand All @@ -34,7 +36,28 @@ public SnowflakeDbParameter(int ParameterIndex, SFDataType SFDataType)
this.SFDataType = SFDataType;
}

public override DbType DbType { get; set; }
public override DbType DbType
{
get
{
if (_dbType != default(DbType) || Value == null || Value is DBNull)
{
return _dbType;
}

var type = Value.GetType();
if (type.IsArray && type != typeof(byte[]))
{
return SFDataConverter.TypeToDbTypeMap[type.GetElementType()];
}
else
{
return SFDataConverter.TypeToDbTypeMap[type];
}
}

set => _dbType = value;
}

public override ParameterDirection Direction
{
Expand Down
25 changes: 25 additions & 0 deletions Snowflake.Data/Core/SFDataConverter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
*/

using System;
using System.Collections.Generic;
using System.Data;
using System.Globalization;
using System.Text;
Expand All @@ -20,6 +21,30 @@ static class SFDataConverter
{
internal static readonly DateTime UnixEpoch = new DateTime(1970, 1, 1, 0, 0, 0, DateTimeKind.Utc);

internal static readonly Dictionary<Type, DbType> TypeToDbTypeMap = new Dictionary<Type, DbType>()
{
[typeof(byte)] = DbType.Byte,
[typeof(sbyte)] = DbType.SByte,
[typeof(short)] = DbType.Int16,
[typeof(ushort)] = DbType.UInt16,
[typeof(int)] = DbType.Int32,
[typeof(uint)] = DbType.UInt32,
[typeof(long)] = DbType.Int64,
[typeof(ulong)] = DbType.UInt64,
[typeof(float)] = DbType.Single,
[typeof(double)] = DbType.Double,
[typeof(decimal)] = DbType.Decimal,
[typeof(bool)] = DbType.Boolean,
[typeof(string)] = DbType.String,
[typeof(char)] = DbType.StringFixedLength,
[typeof(Guid)] = DbType.Guid,
[typeof(DateTime)] = DbType.DateTime,
[typeof(DateTimeOffset)] = DbType.DateTimeOffset,
[typeof(TimeSpan)] = DbType.Time,
[typeof(byte[])] = DbType.Binary,
[typeof(object)] = DbType.Object
};

internal static object ConvertToCSharpVal(UTF8Buffer srcVal, SFDataType srcType, Type destType)
{
if (srcVal == null)
Expand Down

0 comments on commit ac0860f

Please sign in to comment.