From ef6db7e4f5efb523e5064f0788d464dbb6338b0a Mon Sep 17 00:00:00 2001 From: Juan Martinez Ramirez <126511805+sfc-gh-jmartinezramirez@users.noreply.github.com> Date: Tue, 8 Oct 2024 09:39:28 -0600 Subject: [PATCH] Added SnowflakeDbDataReader implementation of GetEnumerator using DbEnumerator class (#1031) --- .../SFDbDataReaderGetEnumeratorIT.cs | 180 ++++++++++++++++++ .../Client/SnowflakeDbDataReader.cs | 5 +- 2 files changed, 181 insertions(+), 4 deletions(-) create mode 100755 Snowflake.Data.Tests/IntegrationTests/SFDbDataReaderGetEnumeratorIT.cs diff --git a/Snowflake.Data.Tests/IntegrationTests/SFDbDataReaderGetEnumeratorIT.cs b/Snowflake.Data.Tests/IntegrationTests/SFDbDataReaderGetEnumeratorIT.cs new file mode 100755 index 000000000..88e25256e --- /dev/null +++ b/Snowflake.Data.Tests/IntegrationTests/SFDbDataReaderGetEnumeratorIT.cs @@ -0,0 +1,180 @@ +/* + * Copyright (c) 2012-2019 Snowflake Computing Inc. All rights reserved. + */ + +using System; +using System.Linq; +using System.Data.Common; +using System.Data; +using System.Globalization; +using System.Text; +using NUnit.Framework; +using Snowflake.Data.Client; +using Snowflake.Data.Core; +using Snowflake.Data.Tests.Util; + +namespace Snowflake.Data.Tests.IntegrationTests +{ + [TestFixture(ResultFormat.ARROW)] + [TestFixture(ResultFormat.JSON)] + class SFDbDataReaderGetEnumeratorIT : SFBaseTest + { + protected override string TestName => base.TestName + _resultFormat; + + private readonly ResultFormat _resultFormat; + + public SFDbDataReaderGetEnumeratorIT(ResultFormat resultFormat) + { + _resultFormat = resultFormat; + } + + [Test] + public void TestGetEnumerator() + { + using (var conn = CreateAndOpenConnection()) + { + CreateAndPopulateTestTable(conn); + + string selectCommandText = $"select * from {TableName}"; + IDbCommand selectCmd = conn.CreateCommand(); + selectCmd.CommandText = selectCommandText; + DbDataReader reader = selectCmd.ExecuteReader() as DbDataReader; + + var enumerator = reader.GetEnumerator(); + Assert.IsTrue(enumerator.MoveNext()); + Assert.AreEqual(3, (enumerator.Current as DbDataRecord).GetInt64(0)); + Assert.IsTrue(enumerator.MoveNext()); + Assert.AreEqual(5, (enumerator.Current as DbDataRecord).GetInt64(0)); + Assert.IsTrue(enumerator.MoveNext()); + Assert.AreEqual(8, (enumerator.Current as DbDataRecord).GetInt64(0)); + Assert.IsFalse(enumerator.MoveNext()); + + reader.Close(); + + DropTestTableAndCloseConnection(conn); + } + } + + [Test] + public void TestGetEnumeratorShouldBeEmptyWhenNotRowsReturned() + { + using (var conn = CreateAndOpenConnection()) + { + CreateAndPopulateTestTable(conn); + + string selectCommandText = $"select * from {TableName} WHERE cola > 10"; + IDbCommand selectCmd = conn.CreateCommand(); + selectCmd.CommandText = selectCommandText; + DbDataReader reader = selectCmd.ExecuteReader() as DbDataReader; + + var enumerator = reader.GetEnumerator(); + Assert.IsFalse(enumerator.MoveNext()); + Assert.IsNull(enumerator.Current); + + reader.Close(); + DropTestTableAndCloseConnection(conn); + } + } + + [Test] + public void TestGetEnumeratorWithCastMethod() + { + using (var conn = CreateAndOpenConnection()) + { + CreateAndPopulateTestTable(conn); + + string selectCommandText = $"select * from {TableName}"; + IDbCommand selectCmd = conn.CreateCommand(); + selectCmd.CommandText = selectCommandText; + DbDataReader reader = selectCmd.ExecuteReader() as DbDataReader; + + var dataRecords = reader.Cast().ToList(); + Assert.AreEqual(3, dataRecords.Count); + + reader.Close(); + + DropTestTableAndCloseConnection(conn); + } + } + + [Test] + public void TestGetEnumeratorForEachShouldNotEnterWhenResultsIsEmpty() + { + using (var conn = CreateAndOpenConnection()) + { + CreateAndPopulateTestTable(conn); + + string selectCommandText = $"select * from {TableName} WHERE cola > 10"; + IDbCommand selectCmd = conn.CreateCommand(); + selectCmd.CommandText = selectCommandText; + DbDataReader reader = selectCmd.ExecuteReader() as DbDataReader; + + foreach (var record in reader) + { + Assert.Fail("Should not enter when results is empty"); + } + + reader.Close(); + DropTestTableAndCloseConnection(conn); + } + } + + [Test] + public void TestGetEnumeratorShouldThrowNonSupportedExceptionWhenReset() + { + using (var conn = CreateAndOpenConnection()) + { + CreateAndPopulateTestTable(conn); + + string selectCommandText = $"select * from {TableName}"; + IDbCommand selectCmd = conn.CreateCommand(); + selectCmd.CommandText = selectCommandText; + DbDataReader reader = selectCmd.ExecuteReader() as DbDataReader; + + var enumerator = reader.GetEnumerator(); + Assert.IsTrue(enumerator.MoveNext()); + + Assert.Throws(() => enumerator.Reset()); + + reader.Close(); + + DropTestTableAndCloseConnection(conn); + } + } + + private void DropTestTableAndCloseConnection(DbConnection conn) + { + IDbCommand cmd = conn.CreateCommand(); + cmd.CommandText = $"drop table if exists {TableName}"; + var count = cmd.ExecuteNonQuery(); + Assert.AreEqual(0, count); + + CloseConnection(conn); + } + + private void CreateAndPopulateTestTable(DbConnection conn) + { + CreateOrReplaceTable(conn, TableName, new []{"cola NUMBER"}); + + var cmd = conn.CreateCommand(); + + string insertCommand = $"insert into {TableName} values (3),(5),(8)"; + cmd.CommandText = insertCommand; + cmd.ExecuteNonQuery(); + } + + private DbConnection CreateAndOpenConnection() + { + var conn = new SnowflakeDbConnection(ConnectionString); + conn.Open(); + SessionParameterAlterer.SetResultFormat(conn, _resultFormat); + return conn; + } + + private void CloseConnection(DbConnection conn) + { + SessionParameterAlterer.RestoreResultFormat(conn); + conn.Close(); + } + } +} diff --git a/Snowflake.Data/Client/SnowflakeDbDataReader.cs b/Snowflake.Data/Client/SnowflakeDbDataReader.cs index b7bc1615e..7d624bd80 100755 --- a/Snowflake.Data/Client/SnowflakeDbDataReader.cs +++ b/Snowflake.Data/Client/SnowflakeDbDataReader.cs @@ -189,10 +189,7 @@ public override double GetDouble(int ordinal) return resultSet.GetDouble(ordinal); } - public override IEnumerator GetEnumerator() - { - throw new NotImplementedException(); - } + public override IEnumerator GetEnumerator() => new DbEnumerator(this, closeReader: false); public override Type GetFieldType(int ordinal) {