Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added SnowflakeDbDataReader implementation of GetEnumerator using DbEnumerator class #1031

Merged
merged 3 commits into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
/*
* Copyright (c) 2012-2019 Snowflake Computing Inc. All rights reserved.
*/

using System;
using System.Linq;
using System.Data.Common;
using System.Data;
using System.Globalization;
using System.Text;
using NUnit.Framework;
using Snowflake.Data.Client;
using Snowflake.Data.Core;
using Snowflake.Data.Tests.Util;

namespace Snowflake.Data.Tests.IntegrationTests
{
[TestFixture(ResultFormat.ARROW)]
[TestFixture(ResultFormat.JSON)]
class SFDbDataReaderGetEnumeratorIT : SFBaseTest
{
protected override string TestName => base.TestName + _resultFormat;

private readonly ResultFormat _resultFormat;

public SFDbDataReaderGetEnumeratorIT(ResultFormat resultFormat)
{
_resultFormat = resultFormat;
}

[Test]
public void TestGetEnumerator()
{
using (var conn = CreateAndOpenConnection())
{
CreateAndPopulateTestTable(conn);

string selectCommandText = $"select * from {TableName}";
IDbCommand selectCmd = conn.CreateCommand();
selectCmd.CommandText = selectCommandText;
DbDataReader reader = selectCmd.ExecuteReader() as DbDataReader;

var enumerator = reader.GetEnumerator();
Assert.IsTrue(enumerator.MoveNext());
Assert.AreEqual(3, (enumerator.Current as DbDataRecord).GetInt64(0));
Assert.IsTrue(enumerator.MoveNext());
Assert.AreEqual(5, (enumerator.Current as DbDataRecord).GetInt64(0));
Assert.IsTrue(enumerator.MoveNext());
Assert.AreEqual(8, (enumerator.Current as DbDataRecord).GetInt64(0));
Assert.IsFalse(enumerator.MoveNext());

reader.Close();

DropTestTableAndCloseConnection(conn);
}
}

[Test]
public void TestGetEnumeratorShouldBeEmptyWhenNotRowsReturned()
{
using (var conn = CreateAndOpenConnection())
{
CreateAndPopulateTestTable(conn);

string selectCommandText = $"select * from {TableName} WHERE cola > 10";
IDbCommand selectCmd = conn.CreateCommand();
selectCmd.CommandText = selectCommandText;
DbDataReader reader = selectCmd.ExecuteReader() as DbDataReader;

var enumerator = reader.GetEnumerator();
Assert.IsFalse(enumerator.MoveNext());
Assert.IsNull(enumerator.Current);

reader.Close();
DropTestTableAndCloseConnection(conn);
}
}

[Test]
public void TestGetEnumeratorWithCastMethod()
{
using (var conn = CreateAndOpenConnection())
{
CreateAndPopulateTestTable(conn);

string selectCommandText = $"select * from {TableName}";
IDbCommand selectCmd = conn.CreateCommand();
selectCmd.CommandText = selectCommandText;
DbDataReader reader = selectCmd.ExecuteReader() as DbDataReader;

var dataRecords = reader.Cast<DbDataRecord>().ToList();
Assert.AreEqual(3, dataRecords.Count);

reader.Close();

DropTestTableAndCloseConnection(conn);
}
}

[Test]
public void TestGetEnumeratorForEachShouldNotEnterWhenResultsIsEmpty()
{
using (var conn = CreateAndOpenConnection())
{
CreateAndPopulateTestTable(conn);

string selectCommandText = $"select * from {TableName} WHERE cola > 10";
IDbCommand selectCmd = conn.CreateCommand();
selectCmd.CommandText = selectCommandText;
DbDataReader reader = selectCmd.ExecuteReader() as DbDataReader;

foreach (var record in reader)
{
Assert.Fail("Should not enter when results is empty");
}

reader.Close();
DropTestTableAndCloseConnection(conn);
}
}

[Test]
public void TestGetEnumeratorShouldThrowNonSupportedExceptionWhenReset()
{
using (var conn = CreateAndOpenConnection())
{
CreateAndPopulateTestTable(conn);

string selectCommandText = $"select * from {TableName}";
IDbCommand selectCmd = conn.CreateCommand();
selectCmd.CommandText = selectCommandText;
DbDataReader reader = selectCmd.ExecuteReader() as DbDataReader;

var enumerator = reader.GetEnumerator();
Assert.IsTrue(enumerator.MoveNext());

Assert.Throws<NotSupportedException>(() => enumerator.Reset());

reader.Close();

DropTestTableAndCloseConnection(conn);
}
}

private void DropTestTableAndCloseConnection(DbConnection conn)
{
IDbCommand cmd = conn.CreateCommand();
cmd.CommandText = $"drop table if exists {TableName}";
var count = cmd.ExecuteNonQuery();
Assert.AreEqual(0, count);

CloseConnection(conn);
}

private void CreateAndPopulateTestTable(DbConnection conn)
{
CreateOrReplaceTable(conn, TableName, new []{"cola NUMBER"});

var cmd = conn.CreateCommand();

string insertCommand = $"insert into {TableName} values (3),(5),(8)";
cmd.CommandText = insertCommand;
cmd.ExecuteNonQuery();
}

private DbConnection CreateAndOpenConnection()
{
var conn = new SnowflakeDbConnection(ConnectionString);
conn.Open();
SessionParameterAlterer.SetResultFormat(conn, _resultFormat);
return conn;
}

private void CloseConnection(DbConnection conn)
{
SessionParameterAlterer.RestoreResultFormat(conn);
conn.Close();
}
}
}
5 changes: 1 addition & 4 deletions Snowflake.Data/Client/SnowflakeDbDataReader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -189,10 +189,7 @@ public override double GetDouble(int ordinal)
return resultSet.GetDouble(ordinal);
}

public override IEnumerator GetEnumerator()
{
throw new NotImplementedException();
}
public override IEnumerator GetEnumerator() => new DbEnumerator(this, closeReader: false);

public override Type GetFieldType(int ordinal)
{
Expand Down
Loading