From 62358d41214d395068cf9f3ce7f923723385d90b Mon Sep 17 00:00:00 2001 From: sfc-gh-fgonzalezmendez Date: Wed, 21 Aug 2024 11:20:18 -0600 Subject: [PATCH 1/7] Add support for com.snowflake.snowpark.Row.getAs function --- .../java/com/snowflake/snowpark_java/Row.java | 58 +++++++++ .../scala/com/snowflake/snowpark/Row.scala | 40 +++++- .../snowflake/snowpark_test/JavaRowSuite.java | 104 ++++++++++++++++ .../snowflake/snowpark_test/RowSuite.scala | 114 +++++++++++++++++- 4 files changed, 309 insertions(+), 7 deletions(-) diff --git a/src/main/java/com/snowflake/snowpark_java/Row.java b/src/main/java/com/snowflake/snowpark_java/Row.java index 0921a0d6..40959927 100644 --- a/src/main/java/com/snowflake/snowpark_java/Row.java +++ b/src/main/java/com/snowflake/snowpark_java/Row.java @@ -425,6 +425,64 @@ public Row getObject(int index) { return (Row) get(index); } + /** + * Returns the value at the specified column index and casts it to the desired type {@code T}. + * + *

Example: + * + *

{@code
+   * Row row = Row.create(1, "Alice", 95.5);
+   * row.getAs(0, Integer.class); // Returns 1 as an Int
+   * row.getAs(1, String.class); // Returns "Alice" as a String
+   * row.getAs(2, Double.class); // Returns 95.5 as a Double
+   * }
+ * + * @param index the zero-based column index within the row. + * @param clazz the {@code Class} object representing the type {@code T}. + * @param the expected type of the value at the specified column index. + * @return the value at the specified column index cast to type {@code T}. + * @throws ClassCastException if the value at the given index cannot be cast to type {@code T}. + * @throws ArrayIndexOutOfBoundsException if the column index is out of bounds. + * @since 1.14.0 + */ + @SuppressWarnings("unchecked") + public T getAs(int index, Class clazz) + throws ClassCastException, ArrayIndexOutOfBoundsException { + if (isNullAt(index)) { + return (T) get(index); + } + + if (clazz == Byte.class) { + return (T) (Object) getByte(index); + } + + if (clazz == Double.class) { + return (T) (Object) getDouble(index); + } + + if (clazz == Float.class) { + return (T) (Object) getFloat(index); + } + + if (clazz == Integer.class) { + return (T) (Object) getInt(index); + } + + if (clazz == Long.class) { + return (T) (Object) getLong(index); + } + + if (clazz == Short.class) { + return (T) (Object) getShort(index); + } + + if (clazz == Variant.class) { + return (T) getVariant(index); + } + + return (T) get(index); + } + /** * Generates a string value to represent the content of this row. * diff --git a/src/main/scala/com/snowflake/snowpark/Row.scala b/src/main/scala/com/snowflake/snowpark/Row.scala index a1dc5aef..64a5db36 100644 --- a/src/main/scala/com/snowflake/snowpark/Row.scala +++ b/src/main/scala/com/snowflake/snowpark/Row.scala @@ -367,6 +367,39 @@ class Row protected (values: Array[Any]) extends Serializable { getAs[Map[T, U]](index) } + /** + * Returns the value at the specified column index and casts it to the desired type `T`. + * + * Example: + * {{{ + * val row = Row(1, "Alice", 95.5) + * row.getAs[Int](0) // Returns 1 as an Int + * row.getAs[String](1) // Returns "Alice" as a String + * row.getAs[Double](2) // Returns 95.5 as a Double + * }}} + * + * @param index the zero-based column index within the row. + * @tparam T the expected type of the value at the specified column index. + * @return the value at the specified column index cast to type `T`. + * @throws ClassCastException if the value at the given index cannot be cast to type `T`. + * @throws ArrayIndexOutOfBoundsException if the column index is out of bounds. + * @group getter + * @since 1.14.0 + */ + def getAs[T](index: Int)(implicit classTag: ClassTag[T]): T = { + classTag.runtimeClass match { + case _ if isNullAt(index) => get(index).asInstanceOf[T] + case c if c == classOf[Byte] => getByte(index).asInstanceOf[T] + case c if c == classOf[Double] => getDouble(index).asInstanceOf[T] + case c if c == classOf[Float] => getFloat(index).asInstanceOf[T] + case c if c == classOf[Int] => getInt(index).asInstanceOf[T] + case c if c == classOf[Long] => getLong(index).asInstanceOf[T] + case c if c == classOf[Short] => getShort(index).asInstanceOf[T] + case c if c == classOf[Variant] => getVariant(index).asInstanceOf[T] + case _ => get(index).asInstanceOf[T] + } + } + protected def convertValueToString(value: Any): String = value match { case null => "null" @@ -400,10 +433,7 @@ class Row protected (values: Array[Any]) extends Serializable { .map(convertValueToString) .mkString("Row[", ",", "]") - private def getAs[T](index: Int): T = get(index).asInstanceOf[T] - - private def getAnyValAs[T <: AnyVal](index: Int): T = + private def getAnyValAs[T <: AnyVal](index: Int)(implicit classTag: ClassTag[T]): T = if (isNullAt(index)) throw new NullPointerException(s"Value at index $index is null") - else getAs[T](index) - + else getAs[T](index)(classTag) } diff --git a/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java index 0570c421..df15f438 100644 --- a/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java +++ b/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java @@ -1,5 +1,7 @@ package com.snowflake.snowpark_test; +import static org.junit.Assert.assertThrows; + import com.snowflake.snowpark_java.DataFrame; import com.snowflake.snowpark_java.Row; import com.snowflake.snowpark_java.types.*; @@ -429,4 +431,106 @@ public void testGetRow() { }, getSession()); } + + @Test + public void getAs() { + long milliseconds = System.currentTimeMillis(); + + StructType schema = + StructType.create( + new StructField("c01", DataTypes.BinaryType), + new StructField("c02", DataTypes.BooleanType), + new StructField("c03", DataTypes.ByteType), + new StructField("c04", DataTypes.DateType), + new StructField("c05", DataTypes.DoubleType), + new StructField("c06", DataTypes.FloatType), + new StructField("c07", DataTypes.GeographyType), + new StructField("c08", DataTypes.GeometryType), + new StructField("c09", DataTypes.IntegerType), + new StructField("c10", DataTypes.LongType), + new StructField("c11", DataTypes.ShortType), + new StructField("c12", DataTypes.StringType), + new StructField("c13", DataTypes.TimeType), + new StructField("c14", DataTypes.TimestampType), + new StructField("c15", DataTypes.VariantType)); + + Row[] data = { + Row.create( + new byte[] {1, 2}, + true, + Byte.MIN_VALUE, + Date.valueOf("2024-01-01"), + Double.MIN_VALUE, + Float.MIN_VALUE, + Geography.fromGeoJSON("POINT(30 10)"), + Geometry.fromGeoJSON("POINT(20 40)"), + Integer.MIN_VALUE, + Long.MIN_VALUE, + Short.MIN_VALUE, + "string", + Time.valueOf("16:23:04"), + new Timestamp(milliseconds), + new Variant(1)) + }; + + DataFrame df = getSession().createDataFrame(data, schema); + Row row = df.collect()[0]; + + assert Arrays.equals(row.getAs(0, byte[].class), new byte[] {1, 2}); + assert row.getAs(1, Boolean.class); + assert row.getAs(2, Byte.class) == Byte.MIN_VALUE; + assert row.getAs(3, Date.class).equals(Date.valueOf("2024-01-01")); + assert row.getAs(4, Double.class) == Double.MIN_VALUE; + assert row.getAs(5, Float.class) == Float.MIN_VALUE; + assert row.getAs(6, Geography.class) + .equals( + Geography.fromGeoJSON( + "{\n \"coordinates\": [\n 30,\n 10\n ],\n \"type\": \"Point\"\n}")); + assert row.getAs(7, Geometry.class) + .equals( + Geometry.fromGeoJSON( + "{\n \"coordinates\": [\n 2.000000000000000e+01,\n 4.000000000000000e+01\n ],\n \"type\": \"Point\"\n}")); + assert row.getAs(8, Integer.class) == Integer.MIN_VALUE; + assert row.getAs(9, Long.class) == Long.MIN_VALUE; + assert row.getAs(10, Short.class) == Short.MIN_VALUE; + assert row.getAs(11, String.class).equals("string"); + assert row.getAs(12, Time.class).equals(Time.valueOf("16:23:04")); + assert row.getAs(13, Timestamp.class).equals(new Timestamp(milliseconds)); + assert row.getAs(14, Variant.class).equals(new Variant(1)); + + Row finalRow = row; + assertThrows( + ClassCastException.class, + () -> { + Boolean b = finalRow.getAs(0, Boolean.class); + }); + assertThrows( + ArrayIndexOutOfBoundsException.class, () -> finalRow.getAs(-1, Boolean.class)); + + data = + new Row[] { + Row.create( + null, null, null, null, null, null, null, null, null, null, null, null, null, null, + null) + }; + + df = getSession().createDataFrame(data, schema); + row = df.collect()[0]; + + assert row.getAs(0, byte[].class) == null; + assert row.getAs(1, Boolean.class) == null; + assert row.getAs(2, Byte.class) == null; + assert row.getAs(3, Date.class) == null; + assert row.getAs(4, Double.class) == null; + assert row.getAs(5, Float.class) == null; + assert row.getAs(6, Geography.class) == null; + assert row.getAs(7, Geometry.class) == null; + assert row.getAs(8, Integer.class) == null; + assert row.getAs(9, Long.class) == null; + assert row.getAs(10, Short.class) == null; + assert row.getAs(11, String.class) == null; + assert row.getAs(12, Time.class) == null; + assert row.getAs(13, Timestamp.class) == null; + assert row.getAs(14, Variant.class) == null; + } } diff --git a/src/test/scala/com/snowflake/snowpark_test/RowSuite.scala b/src/test/scala/com/snowflake/snowpark_test/RowSuite.scala index 54aba687..5491f6f1 100644 --- a/src/test/scala/com/snowflake/snowpark_test/RowSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/RowSuite.scala @@ -1,9 +1,9 @@ package com.snowflake.snowpark_test -import com.snowflake.snowpark.types.{Geography, Geometry, Variant} +import com.snowflake.snowpark.types._ import com.snowflake.snowpark.{Row, SNTestBase, SnowparkClientException} -import java.sql.{Date, Timestamp} +import java.sql.{Date, Time, Timestamp} import java.time.{Instant, LocalDate} import java.util @@ -233,6 +233,116 @@ class RowSuite extends SNTestBase { assert(err.message.matches(msg)) } + test("getAs") { + val milliseconds = System.currentTimeMillis() + + val schema = StructType( + Seq( + StructField("c01", BinaryType), + StructField("c02", BooleanType), + StructField("c03", ByteType), + StructField("c04", DateType), + StructField("c05", DoubleType), + StructField("c06", FloatType), + StructField("c07", GeographyType), + StructField("c08", GeometryType), + StructField("c09", IntegerType), + StructField("c10", LongType), + StructField("c11", ShortType), + StructField("c12", StringType), + StructField("c13", TimeType), + StructField("c14", TimestampType), + StructField("c15", VariantType))) + + var data = Seq( + Row( + Array[Byte](1, 9), + true, + Byte.MinValue, + Date.valueOf("2024-01-01"), + Double.MinValue, + Float.MinPositiveValue, + Geography.fromGeoJSON("POINT(30 10)"), + Geometry.fromGeoJSON("POINT(20 40)"), + Int.MinValue, + Long.MinValue, + Short.MinValue, + "string", + Time.valueOf("16:23:04"), + new Timestamp(milliseconds), + new Variant(1))) + + var df = session.createDataFrame(data, schema) + var row = df.collect()(0) + + assert(row.getAs[Array[Byte]](0) sameElements Array[Byte](1, 9)) + assert(row.getAs[Boolean](1)) + assert(row.getAs[Byte](2) == Byte.MinValue) + assert(row.getAs[Date](3) == Date.valueOf("2024-01-01")) + assert(row.getAs[Double](4) == Double.MinValue) + assert(row.getAs[Float](5) == Float.MinPositiveValue) + assert(row.getAs[Geography](6) == Geography.fromGeoJSON("""{ + | "coordinates": [ + | 30, + | 10 + | ], + | "type": "Point" + |}""".stripMargin)) + assert(row.getAs[Geometry](7) == Geometry.fromGeoJSON("""{ + | "coordinates": [ + | 2.000000000000000e+01, + | 4.000000000000000e+01 + | ], + | "type": "Point" + |}""".stripMargin)) + assert(row.getAs[Int](8) == Int.MinValue) + assert(row.getAs[Long](9) == Long.MinValue) + assert(row.getAs[Short](10) == Short.MinValue) + assert(row.getAs[String](11) == "string") + assert(row.getAs[Time](12) == Time.valueOf("16:23:04")) + assert(row.getAs[Timestamp](13) == new Timestamp(milliseconds)) + assert(row.getAs[Variant](14) == new Variant(1)) + assertThrows[ClassCastException](row.getAs[Boolean](0)) + assertThrows[ArrayIndexOutOfBoundsException](row.getAs[Boolean](-1)) + + data = Seq( + Row( + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null)) + + df = session.createDataFrame(data, schema) + row = df.collect()(0) + + assert(row.getAs[Array[Byte]](0) == null) + assert(!row.getAs[Boolean](1)) + assert(row.getAs[Byte](2) == 0) + assert(row.getAs[Date](3) == null) + assert(row.getAs[Double](4) == 0) + assert(row.getAs[Float](5) == 0) + assert(row.getAs[Geography](6) == null) + assert(row.getAs[Geometry](7) == null) + assert(row.getAs[Int](8) == 0) + assert(row.getAs[Long](9) == 0) + assert(row.getAs[Short](10) == 0) + assert(row.getAs[String](11) == null) + assert(row.getAs[Time](12) == null) + assert(row.getAs[Timestamp](13) == null) + assert(row.getAs[Variant](14) == null) + } + test("hashCode") { val row1 = Row(1, 2, 3) val row2 = Row("str", null, 3) From ccc70b5be19e910c2ec68aca9b07a7490e66329d Mon Sep 17 00:00:00 2001 From: sfc-gh-fgonzalezmendez Date: Wed, 21 Aug 2024 12:19:07 -0600 Subject: [PATCH 2/7] Fix code formatting --- src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java index df15f438..c57ab9a2 100644 --- a/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java +++ b/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java @@ -504,8 +504,7 @@ public void getAs() { () -> { Boolean b = finalRow.getAs(0, Boolean.class); }); - assertThrows( - ArrayIndexOutOfBoundsException.class, () -> finalRow.getAs(-1, Boolean.class)); + assertThrows(ArrayIndexOutOfBoundsException.class, () -> finalRow.getAs(-1, Boolean.class)); data = new Row[] { From da145407337bec73c767363c4a5667da70ece34b Mon Sep 17 00:00:00 2001 From: Fabian Gonzalez Mendez Date: Fri, 30 Aug 2024 17:10:31 -0600 Subject: [PATCH 3/7] Add tests scenarios for structured maps and structured arrays --- .../snowflake/snowpark_test/JavaRowSuite.java | 61 +++++++++++++++++++ .../snowflake/snowpark_test/RowSuite.scala | 56 +++++++++++++++++ 2 files changed, 117 insertions(+) diff --git a/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java index c57ab9a2..0c046630 100644 --- a/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java +++ b/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java @@ -9,10 +9,12 @@ import java.sql.Date; import java.sql.Time; import java.sql.Timestamp; +import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.TimeZone; import org.junit.Test; public class JavaRowSuite extends TestBase { @@ -532,4 +534,63 @@ public void getAs() { assert row.getAs(13, Timestamp.class) == null; assert row.getAs(14, Variant.class) == null; } + + @Test + public void getAsWithStructuredMap() { + structuredTypeTest( + () -> { + String query = + "SELECT " + + "{'a':1,'b':2}::MAP(VARCHAR, NUMBER) as map1," + + "{'1':'a','2':'b'}::MAP(NUMBER, VARCHAR) as map2," + + "{'1':{'a':1,'b':2},'2':{'c':3}}::MAP(NUMBER, MAP(VARCHAR, NUMBER)) as map3"; + + DataFrame df = getSession().sql(query); + Row row = df.collect()[0]; + + Map map1 = row.getAs(0, Map.class); + assert (Long) map1.get("a") == 1L; + assert (Long) map1.get("b") == 2L; + + Map map2 = row.getAs(1, Map.class); + assert map2.get(1L).equals("a"); + assert map2.get(2L).equals("b"); + + Map map3 = row.getAs(2, Map.class); + assert map3.get(1L).equals(Map.of("a", 1L, "b", 2L)); + assert map3.get(2L).equals(Map.of("c", 3L)); + }, + getSession()); + } + + @Test + public void getAsWithStructuredArray() { + structuredTypeTest( + () -> { + TimeZone.setDefault(TimeZone.getTimeZone("US/Pacific")); + + String query = + "SELECT " + + "[1,2,3]::ARRAY(NUMBER) AS arr1," + + "['a','b']::ARRAY(VARCHAR) AS arr2," + + "[parse_json(31000000)::timestamp_ntz]::ARRAY(TIMESTAMP_NTZ) AS arr3," + + "[[1,2]]::ARRAY(ARRAY) AS arr4"; + + DataFrame df = getSession().sql(query); + Row row = df.collect()[0]; + + ArrayList array1 = row.getAs(0, ArrayList.class); + assert array1.equals(List.of(1L, 2L, 3L)); + + ArrayList array2 = row.getAs(1, ArrayList.class); + assert array2.equals(List.of("a", "b")); + + ArrayList array3 = row.getAs(2, ArrayList.class); + assert array3.equals(List.of(new Timestamp(31000000000L))); + + ArrayList array4 = row.getAs(3, ArrayList.class); + assert array4.equals(List.of("[\n 1,\n 2\n]")); + }, + getSession()); + } } diff --git a/src/test/scala/com/snowflake/snowpark_test/RowSuite.scala b/src/test/scala/com/snowflake/snowpark_test/RowSuite.scala index 5491f6f1..5ba449dd 100644 --- a/src/test/scala/com/snowflake/snowpark_test/RowSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/RowSuite.scala @@ -6,6 +6,7 @@ import com.snowflake.snowpark.{Row, SNTestBase, SnowparkClientException} import java.sql.{Date, Time, Timestamp} import java.time.{Instant, LocalDate} import java.util +import java.util.TimeZone class RowSuite extends SNTestBase { @@ -343,6 +344,61 @@ class RowSuite extends SNTestBase { assert(row.getAs[Variant](14) == null) } + test("getAs with structured map") { + structuredTypeTest { + val query = + """SELECT + | {'a':1,'b':2}::MAP(VARCHAR, NUMBER) as map1, + | {'1':'a','2':'b'}::MAP(NUMBER, VARCHAR) as map2, + | {'1':{'a':1,'b':2},'2':{'c':3}}::MAP(NUMBER, MAP(VARCHAR, NUMBER)) as map3 + |""".stripMargin + + val df = session.sql(query) + val row = df.collect()(0) + + val map1 = row.getAs[Map[String, Long]](0) + assert(map1("a") == 1L) + assert(map1("b") == 2L) + + val map2 = row.getAs[Map[Long, String]](1) + assert(map2(1) == "a") + assert(map2(2) == "b") + + val map3 = row.getAs[Map[Long, Map[String, Long]]](2) + assert(map3(1) == Map("a" -> 1, "b" -> 2)) + assert(map3(2) == Map("c" -> 3)) + } + } + + test("getAs with structured array") { + structuredTypeTest { + TimeZone.setDefault(TimeZone.getTimeZone("US/Pacific")) + + val query = + """SELECT + | [1,2,3]::ARRAY(NUMBER) AS arr1, + | ['a','b']::ARRAY(VARCHAR) AS arr2, + | [parse_json(31000000)::timestamp_ntz]::ARRAY(TIMESTAMP_NTZ) AS arr3, + | [[1,2]]::ARRAY(ARRAY) AS arr4 + |""".stripMargin + + val df = session.sql(query) + val row = df.collect()(0) + + val array1 = row.getAs[Array[Object]](0) + assert(array1 sameElements Array(1, 2, 3)) + + val array2 = row.getAs[Array[Object]](1) + assert(array2 sameElements Array("a", "b")) + + val array3 = row.getAs[Array[Object]](2) + assert(array3 sameElements Array(new Timestamp(31000000000L))) + + val array4 = row.getAs[Array[Object]](3) + assert(array4 sameElements Array("[\n 1,\n 2\n]")) + } + } + test("hashCode") { val row1 = Row(1, 2, 3) val row2 = Row("str", null, 3) From 29ff272514764594059d3a6d9525da657f6cedad Mon Sep 17 00:00:00 2001 From: Fabian Gonzalez Mendez Date: Fri, 30 Aug 2024 19:14:34 -0600 Subject: [PATCH 4/7] Use collections compatible with Java 8 --- .../snowflake/snowpark_test/JavaRowSuite.java | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java index 0c046630..6ef5c4f1 100644 --- a/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java +++ b/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java @@ -11,6 +11,7 @@ import java.sql.Timestamp; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -557,8 +558,11 @@ public void getAsWithStructuredMap() { assert map2.get(2L).equals("b"); Map map3 = row.getAs(2, Map.class); - assert map3.get(1L).equals(Map.of("a", 1L, "b", 2L)); - assert map3.get(2L).equals(Map.of("c", 3L)); + Map map3ExpectedInnerMap = new HashMap<>(); + map3ExpectedInnerMap.put("a", 1L); + map3ExpectedInnerMap.put("b", 2L); + assert map3.get(1L).equals(map3ExpectedInnerMap); + assert map3.get(2L).equals(Collections.singletonMap("c", 3L)); }, getSession()); } @@ -580,16 +584,16 @@ public void getAsWithStructuredArray() { Row row = df.collect()[0]; ArrayList array1 = row.getAs(0, ArrayList.class); - assert array1.equals(List.of(1L, 2L, 3L)); + assert array1.equals(Arrays.asList(1L, 2L, 3L)); ArrayList array2 = row.getAs(1, ArrayList.class); - assert array2.equals(List.of("a", "b")); + assert array2.equals(Arrays.asList("a", "b")); ArrayList array3 = row.getAs(2, ArrayList.class); - assert array3.equals(List.of(new Timestamp(31000000000L))); + assert array3.equals(Collections.singletonList(new Timestamp(31000000000L))); ArrayList array4 = row.getAs(3, ArrayList.class); - assert array4.equals(List.of("[\n 1,\n 2\n]")); + assert array4.equals(Collections.singletonList("[\n 1,\n 2\n]")); }, getSession()); } From 741cb73d6f21f97535efa8f66715ba55ca96bacd Mon Sep 17 00:00:00 2001 From: Fabian Gonzalez Mendez Date: Mon, 2 Sep 2024 09:51:02 -0600 Subject: [PATCH 5/7] Update tests for structured types --- .../snowflake/snowpark_test/JavaRowSuite.java | 51 ++++++++++-------- .../snowflake/snowpark_test/RowSuite.scala | 53 ++++++++++--------- 2 files changed, 57 insertions(+), 47 deletions(-) diff --git a/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java index 6ef5c4f1..938290f7 100644 --- a/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java +++ b/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java @@ -571,29 +571,34 @@ public void getAsWithStructuredMap() { public void getAsWithStructuredArray() { structuredTypeTest( () -> { - TimeZone.setDefault(TimeZone.getTimeZone("US/Pacific")); - - String query = - "SELECT " - + "[1,2,3]::ARRAY(NUMBER) AS arr1," - + "['a','b']::ARRAY(VARCHAR) AS arr2," - + "[parse_json(31000000)::timestamp_ntz]::ARRAY(TIMESTAMP_NTZ) AS arr3," - + "[[1,2]]::ARRAY(ARRAY) AS arr4"; - - DataFrame df = getSession().sql(query); - Row row = df.collect()[0]; - - ArrayList array1 = row.getAs(0, ArrayList.class); - assert array1.equals(Arrays.asList(1L, 2L, 3L)); - - ArrayList array2 = row.getAs(1, ArrayList.class); - assert array2.equals(Arrays.asList("a", "b")); - - ArrayList array3 = row.getAs(2, ArrayList.class); - assert array3.equals(Collections.singletonList(new Timestamp(31000000000L))); - - ArrayList array4 = row.getAs(3, ArrayList.class); - assert array4.equals(Collections.singletonList("[\n 1,\n 2\n]")); + var oldTimeZone = TimeZone.getDefault(); + try { + TimeZone.setDefault(TimeZone.getTimeZone("US/Pacific")); + + String query = + "SELECT " + + "[1,2,3]::ARRAY(NUMBER) AS arr1," + + "['a','b']::ARRAY(VARCHAR) AS arr2," + + "[parse_json(31000000)::timestamp_ntz]::ARRAY(TIMESTAMP_NTZ) AS arr3," + + "[[1,2]]::ARRAY(ARRAY) AS arr4"; + + DataFrame df = getSession().sql(query); + Row row = df.collect()[0]; + + ArrayList array1 = row.getAs(0, ArrayList.class); + assert array1.equals(Arrays.asList(1L, 2L, 3L)); + + ArrayList array2 = row.getAs(1, ArrayList.class); + assert array2.equals(Arrays.asList("a", "b")); + + ArrayList array3 = row.getAs(2, ArrayList.class); + assert array3.equals(Collections.singletonList(new Timestamp(31000000000L))); + + ArrayList array4 = row.getAs(3, ArrayList.class); + assert array4.equals(Collections.singletonList("[\n 1,\n 2\n]")); + } finally { + TimeZone.setDefault(oldTimeZone); + } }, getSession()); } diff --git a/src/test/scala/com/snowflake/snowpark_test/RowSuite.scala b/src/test/scala/com/snowflake/snowpark_test/RowSuite.scala index 5ba449dd..df87666f 100644 --- a/src/test/scala/com/snowflake/snowpark_test/RowSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/RowSuite.scala @@ -372,30 +372,35 @@ class RowSuite extends SNTestBase { test("getAs with structured array") { structuredTypeTest { - TimeZone.setDefault(TimeZone.getTimeZone("US/Pacific")) - - val query = - """SELECT - | [1,2,3]::ARRAY(NUMBER) AS arr1, - | ['a','b']::ARRAY(VARCHAR) AS arr2, - | [parse_json(31000000)::timestamp_ntz]::ARRAY(TIMESTAMP_NTZ) AS arr3, - | [[1,2]]::ARRAY(ARRAY) AS arr4 - |""".stripMargin - - val df = session.sql(query) - val row = df.collect()(0) - - val array1 = row.getAs[Array[Object]](0) - assert(array1 sameElements Array(1, 2, 3)) - - val array2 = row.getAs[Array[Object]](1) - assert(array2 sameElements Array("a", "b")) - - val array3 = row.getAs[Array[Object]](2) - assert(array3 sameElements Array(new Timestamp(31000000000L))) - - val array4 = row.getAs[Array[Object]](3) - assert(array4 sameElements Array("[\n 1,\n 2\n]")) + val oldTimeZone = TimeZone.getDefault + try { + TimeZone.setDefault(TimeZone.getTimeZone("US/Pacific")) + + val query = + """SELECT + | [1,2,3]::ARRAY(NUMBER) AS arr1, + | ['a','b']::ARRAY(VARCHAR) AS arr2, + | [parse_json(31000000)::timestamp_ntz]::ARRAY(TIMESTAMP_NTZ) AS arr3, + | [[1,2]]::ARRAY(ARRAY) AS arr4 + |""".stripMargin + + val df = session.sql(query) + val row = df.collect()(0) + + val array1 = row.getAs[Array[Object]](0) + assert(array1 sameElements Array(1, 2, 3)) + + val array2 = row.getAs[Array[Object]](1) + assert(array2 sameElements Array("a", "b")) + + val array3 = row.getAs[Array[Object]](2) + assert(array3 sameElements Array(new Timestamp(31000000000L))) + + val array4 = row.getAs[Array[Object]](3) + assert(array4 sameElements Array("[\n 1,\n 2\n]")) + } finally { + TimeZone.setDefault(oldTimeZone) + } } } From 8bbe2d5030c99647d10d1a2955956081de5191a3 Mon Sep 17 00:00:00 2001 From: Fabian Gonzalez Mendez Date: Mon, 2 Sep 2024 10:01:52 -0600 Subject: [PATCH 6/7] Remove type inference for Java variables --- src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java index 938290f7..f8918292 100644 --- a/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java +++ b/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java @@ -571,7 +571,7 @@ public void getAsWithStructuredMap() { public void getAsWithStructuredArray() { structuredTypeTest( () -> { - var oldTimeZone = TimeZone.getDefault(); + TimeZone oldTimeZone = TimeZone.getDefault(); try { TimeZone.setDefault(TimeZone.getTimeZone("US/Pacific")); From 53dbc50b31da7dc608d3939518b78f0666627bc5 Mon Sep 17 00:00:00 2001 From: Fabian Gonzalez Mendez Date: Wed, 4 Sep 2024 15:23:20 -0600 Subject: [PATCH 7/7] Update the since tag to 1.15.0 --- src/main/java/com/snowflake/snowpark_java/Row.java | 2 +- src/main/scala/com/snowflake/snowpark/Row.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/main/java/com/snowflake/snowpark_java/Row.java b/src/main/java/com/snowflake/snowpark_java/Row.java index 40959927..74475e2d 100644 --- a/src/main/java/com/snowflake/snowpark_java/Row.java +++ b/src/main/java/com/snowflake/snowpark_java/Row.java @@ -443,7 +443,7 @@ public Row getObject(int index) { * @return the value at the specified column index cast to type {@code T}. * @throws ClassCastException if the value at the given index cannot be cast to type {@code T}. * @throws ArrayIndexOutOfBoundsException if the column index is out of bounds. - * @since 1.14.0 + * @since 1.15.0 */ @SuppressWarnings("unchecked") public T getAs(int index, Class clazz) diff --git a/src/main/scala/com/snowflake/snowpark/Row.scala b/src/main/scala/com/snowflake/snowpark/Row.scala index 64a5db36..34eb9bbf 100644 --- a/src/main/scala/com/snowflake/snowpark/Row.scala +++ b/src/main/scala/com/snowflake/snowpark/Row.scala @@ -384,7 +384,7 @@ class Row protected (values: Array[Any]) extends Serializable { * @throws ClassCastException if the value at the given index cannot be cast to type `T`. * @throws ArrayIndexOutOfBoundsException if the column index is out of bounds. * @group getter - * @since 1.14.0 + * @since 1.15.0 */ def getAs[T](index: Int)(implicit classTag: ClassTag[T]): T = { classTag.runtimeClass match {