diff --git a/src/main/java/com/snowflake/snowpark_java/Row.java b/src/main/java/com/snowflake/snowpark_java/Row.java index 0921a0d6..74475e2d 100644 --- a/src/main/java/com/snowflake/snowpark_java/Row.java +++ b/src/main/java/com/snowflake/snowpark_java/Row.java @@ -425,6 +425,64 @@ public Row getObject(int index) { return (Row) get(index); } + /** + * Returns the value at the specified column index and casts it to the desired type {@code T}. + * + *

Example: + * + *

{@code
+   * Row row = Row.create(1, "Alice", 95.5);
+   * row.getAs(0, Integer.class); // Returns 1 as an Int
+   * row.getAs(1, String.class); // Returns "Alice" as a String
+   * row.getAs(2, Double.class); // Returns 95.5 as a Double
+   * }
+ * + * @param index the zero-based column index within the row. + * @param clazz the {@code Class} object representing the type {@code T}. + * @param the expected type of the value at the specified column index. + * @return the value at the specified column index cast to type {@code T}. + * @throws ClassCastException if the value at the given index cannot be cast to type {@code T}. + * @throws ArrayIndexOutOfBoundsException if the column index is out of bounds. + * @since 1.15.0 + */ + @SuppressWarnings("unchecked") + public T getAs(int index, Class clazz) + throws ClassCastException, ArrayIndexOutOfBoundsException { + if (isNullAt(index)) { + return (T) get(index); + } + + if (clazz == Byte.class) { + return (T) (Object) getByte(index); + } + + if (clazz == Double.class) { + return (T) (Object) getDouble(index); + } + + if (clazz == Float.class) { + return (T) (Object) getFloat(index); + } + + if (clazz == Integer.class) { + return (T) (Object) getInt(index); + } + + if (clazz == Long.class) { + return (T) (Object) getLong(index); + } + + if (clazz == Short.class) { + return (T) (Object) getShort(index); + } + + if (clazz == Variant.class) { + return (T) getVariant(index); + } + + return (T) get(index); + } + /** * Generates a string value to represent the content of this row. * diff --git a/src/main/scala/com/snowflake/snowpark/Row.scala b/src/main/scala/com/snowflake/snowpark/Row.scala index a1dc5aef..34eb9bbf 100644 --- a/src/main/scala/com/snowflake/snowpark/Row.scala +++ b/src/main/scala/com/snowflake/snowpark/Row.scala @@ -367,6 +367,39 @@ class Row protected (values: Array[Any]) extends Serializable { getAs[Map[T, U]](index) } + /** + * Returns the value at the specified column index and casts it to the desired type `T`. + * + * Example: + * {{{ + * val row = Row(1, "Alice", 95.5) + * row.getAs[Int](0) // Returns 1 as an Int + * row.getAs[String](1) // Returns "Alice" as a String + * row.getAs[Double](2) // Returns 95.5 as a Double + * }}} + * + * @param index the zero-based column index within the row. + * @tparam T the expected type of the value at the specified column index. + * @return the value at the specified column index cast to type `T`. + * @throws ClassCastException if the value at the given index cannot be cast to type `T`. + * @throws ArrayIndexOutOfBoundsException if the column index is out of bounds. + * @group getter + * @since 1.15.0 + */ + def getAs[T](index: Int)(implicit classTag: ClassTag[T]): T = { + classTag.runtimeClass match { + case _ if isNullAt(index) => get(index).asInstanceOf[T] + case c if c == classOf[Byte] => getByte(index).asInstanceOf[T] + case c if c == classOf[Double] => getDouble(index).asInstanceOf[T] + case c if c == classOf[Float] => getFloat(index).asInstanceOf[T] + case c if c == classOf[Int] => getInt(index).asInstanceOf[T] + case c if c == classOf[Long] => getLong(index).asInstanceOf[T] + case c if c == classOf[Short] => getShort(index).asInstanceOf[T] + case c if c == classOf[Variant] => getVariant(index).asInstanceOf[T] + case _ => get(index).asInstanceOf[T] + } + } + protected def convertValueToString(value: Any): String = value match { case null => "null" @@ -400,10 +433,7 @@ class Row protected (values: Array[Any]) extends Serializable { .map(convertValueToString) .mkString("Row[", ",", "]") - private def getAs[T](index: Int): T = get(index).asInstanceOf[T] - - private def getAnyValAs[T <: AnyVal](index: Int): T = + private def getAnyValAs[T <: AnyVal](index: Int)(implicit classTag: ClassTag[T]): T = if (isNullAt(index)) throw new NullPointerException(s"Value at index $index is null") - else getAs[T](index) - + else getAs[T](index)(classTag) } diff --git a/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java index 0570c421..f8918292 100644 --- a/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java +++ b/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java @@ -1,5 +1,7 @@ package com.snowflake.snowpark_test; +import static org.junit.Assert.assertThrows; + import com.snowflake.snowpark_java.DataFrame; import com.snowflake.snowpark_java.Row; import com.snowflake.snowpark_java.types.*; @@ -7,10 +9,13 @@ import java.sql.Date; import java.sql.Time; import java.sql.Timestamp; +import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.TimeZone; import org.junit.Test; public class JavaRowSuite extends TestBase { @@ -429,4 +434,172 @@ public void testGetRow() { }, getSession()); } + + @Test + public void getAs() { + long milliseconds = System.currentTimeMillis(); + + StructType schema = + StructType.create( + new StructField("c01", DataTypes.BinaryType), + new StructField("c02", DataTypes.BooleanType), + new StructField("c03", DataTypes.ByteType), + new StructField("c04", DataTypes.DateType), + new StructField("c05", DataTypes.DoubleType), + new StructField("c06", DataTypes.FloatType), + new StructField("c07", DataTypes.GeographyType), + new StructField("c08", DataTypes.GeometryType), + new StructField("c09", DataTypes.IntegerType), + new StructField("c10", DataTypes.LongType), + new StructField("c11", DataTypes.ShortType), + new StructField("c12", DataTypes.StringType), + new StructField("c13", DataTypes.TimeType), + new StructField("c14", DataTypes.TimestampType), + new StructField("c15", DataTypes.VariantType)); + + Row[] data = { + Row.create( + new byte[] {1, 2}, + true, + Byte.MIN_VALUE, + Date.valueOf("2024-01-01"), + Double.MIN_VALUE, + Float.MIN_VALUE, + Geography.fromGeoJSON("POINT(30 10)"), + Geometry.fromGeoJSON("POINT(20 40)"), + Integer.MIN_VALUE, + Long.MIN_VALUE, + Short.MIN_VALUE, + "string", + Time.valueOf("16:23:04"), + new Timestamp(milliseconds), + new Variant(1)) + }; + + DataFrame df = getSession().createDataFrame(data, schema); + Row row = df.collect()[0]; + + assert Arrays.equals(row.getAs(0, byte[].class), new byte[] {1, 2}); + assert row.getAs(1, Boolean.class); + assert row.getAs(2, Byte.class) == Byte.MIN_VALUE; + assert row.getAs(3, Date.class).equals(Date.valueOf("2024-01-01")); + assert row.getAs(4, Double.class) == Double.MIN_VALUE; + assert row.getAs(5, Float.class) == Float.MIN_VALUE; + assert row.getAs(6, Geography.class) + .equals( + Geography.fromGeoJSON( + "{\n \"coordinates\": [\n 30,\n 10\n ],\n \"type\": \"Point\"\n}")); + assert row.getAs(7, Geometry.class) + .equals( + Geometry.fromGeoJSON( + "{\n \"coordinates\": [\n 2.000000000000000e+01,\n 4.000000000000000e+01\n ],\n \"type\": \"Point\"\n}")); + assert row.getAs(8, Integer.class) == Integer.MIN_VALUE; + assert row.getAs(9, Long.class) == Long.MIN_VALUE; + assert row.getAs(10, Short.class) == Short.MIN_VALUE; + assert row.getAs(11, String.class).equals("string"); + assert row.getAs(12, Time.class).equals(Time.valueOf("16:23:04")); + assert row.getAs(13, Timestamp.class).equals(new Timestamp(milliseconds)); + assert row.getAs(14, Variant.class).equals(new Variant(1)); + + Row finalRow = row; + assertThrows( + ClassCastException.class, + () -> { + Boolean b = finalRow.getAs(0, Boolean.class); + }); + assertThrows(ArrayIndexOutOfBoundsException.class, () -> finalRow.getAs(-1, Boolean.class)); + + data = + new Row[] { + Row.create( + null, null, null, null, null, null, null, null, null, null, null, null, null, null, + null) + }; + + df = getSession().createDataFrame(data, schema); + row = df.collect()[0]; + + assert row.getAs(0, byte[].class) == null; + assert row.getAs(1, Boolean.class) == null; + assert row.getAs(2, Byte.class) == null; + assert row.getAs(3, Date.class) == null; + assert row.getAs(4, Double.class) == null; + assert row.getAs(5, Float.class) == null; + assert row.getAs(6, Geography.class) == null; + assert row.getAs(7, Geometry.class) == null; + assert row.getAs(8, Integer.class) == null; + assert row.getAs(9, Long.class) == null; + assert row.getAs(10, Short.class) == null; + assert row.getAs(11, String.class) == null; + assert row.getAs(12, Time.class) == null; + assert row.getAs(13, Timestamp.class) == null; + assert row.getAs(14, Variant.class) == null; + } + + @Test + public void getAsWithStructuredMap() { + structuredTypeTest( + () -> { + String query = + "SELECT " + + "{'a':1,'b':2}::MAP(VARCHAR, NUMBER) as map1," + + "{'1':'a','2':'b'}::MAP(NUMBER, VARCHAR) as map2," + + "{'1':{'a':1,'b':2},'2':{'c':3}}::MAP(NUMBER, MAP(VARCHAR, NUMBER)) as map3"; + + DataFrame df = getSession().sql(query); + Row row = df.collect()[0]; + + Map map1 = row.getAs(0, Map.class); + assert (Long) map1.get("a") == 1L; + assert (Long) map1.get("b") == 2L; + + Map map2 = row.getAs(1, Map.class); + assert map2.get(1L).equals("a"); + assert map2.get(2L).equals("b"); + + Map map3 = row.getAs(2, Map.class); + Map map3ExpectedInnerMap = new HashMap<>(); + map3ExpectedInnerMap.put("a", 1L); + map3ExpectedInnerMap.put("b", 2L); + assert map3.get(1L).equals(map3ExpectedInnerMap); + assert map3.get(2L).equals(Collections.singletonMap("c", 3L)); + }, + getSession()); + } + + @Test + public void getAsWithStructuredArray() { + structuredTypeTest( + () -> { + TimeZone oldTimeZone = TimeZone.getDefault(); + try { + TimeZone.setDefault(TimeZone.getTimeZone("US/Pacific")); + + String query = + "SELECT " + + "[1,2,3]::ARRAY(NUMBER) AS arr1," + + "['a','b']::ARRAY(VARCHAR) AS arr2," + + "[parse_json(31000000)::timestamp_ntz]::ARRAY(TIMESTAMP_NTZ) AS arr3," + + "[[1,2]]::ARRAY(ARRAY) AS arr4"; + + DataFrame df = getSession().sql(query); + Row row = df.collect()[0]; + + ArrayList array1 = row.getAs(0, ArrayList.class); + assert array1.equals(Arrays.asList(1L, 2L, 3L)); + + ArrayList array2 = row.getAs(1, ArrayList.class); + assert array2.equals(Arrays.asList("a", "b")); + + ArrayList array3 = row.getAs(2, ArrayList.class); + assert array3.equals(Collections.singletonList(new Timestamp(31000000000L))); + + ArrayList array4 = row.getAs(3, ArrayList.class); + assert array4.equals(Collections.singletonList("[\n 1,\n 2\n]")); + } finally { + TimeZone.setDefault(oldTimeZone); + } + }, + getSession()); + } } diff --git a/src/test/scala/com/snowflake/snowpark_test/RowSuite.scala b/src/test/scala/com/snowflake/snowpark_test/RowSuite.scala index 54aba687..df87666f 100644 --- a/src/test/scala/com/snowflake/snowpark_test/RowSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/RowSuite.scala @@ -1,11 +1,12 @@ package com.snowflake.snowpark_test -import com.snowflake.snowpark.types.{Geography, Geometry, Variant} +import com.snowflake.snowpark.types._ import com.snowflake.snowpark.{Row, SNTestBase, SnowparkClientException} -import java.sql.{Date, Timestamp} +import java.sql.{Date, Time, Timestamp} import java.time.{Instant, LocalDate} import java.util +import java.util.TimeZone class RowSuite extends SNTestBase { @@ -233,6 +234,176 @@ class RowSuite extends SNTestBase { assert(err.message.matches(msg)) } + test("getAs") { + val milliseconds = System.currentTimeMillis() + + val schema = StructType( + Seq( + StructField("c01", BinaryType), + StructField("c02", BooleanType), + StructField("c03", ByteType), + StructField("c04", DateType), + StructField("c05", DoubleType), + StructField("c06", FloatType), + StructField("c07", GeographyType), + StructField("c08", GeometryType), + StructField("c09", IntegerType), + StructField("c10", LongType), + StructField("c11", ShortType), + StructField("c12", StringType), + StructField("c13", TimeType), + StructField("c14", TimestampType), + StructField("c15", VariantType))) + + var data = Seq( + Row( + Array[Byte](1, 9), + true, + Byte.MinValue, + Date.valueOf("2024-01-01"), + Double.MinValue, + Float.MinPositiveValue, + Geography.fromGeoJSON("POINT(30 10)"), + Geometry.fromGeoJSON("POINT(20 40)"), + Int.MinValue, + Long.MinValue, + Short.MinValue, + "string", + Time.valueOf("16:23:04"), + new Timestamp(milliseconds), + new Variant(1))) + + var df = session.createDataFrame(data, schema) + var row = df.collect()(0) + + assert(row.getAs[Array[Byte]](0) sameElements Array[Byte](1, 9)) + assert(row.getAs[Boolean](1)) + assert(row.getAs[Byte](2) == Byte.MinValue) + assert(row.getAs[Date](3) == Date.valueOf("2024-01-01")) + assert(row.getAs[Double](4) == Double.MinValue) + assert(row.getAs[Float](5) == Float.MinPositiveValue) + assert(row.getAs[Geography](6) == Geography.fromGeoJSON("""{ + | "coordinates": [ + | 30, + | 10 + | ], + | "type": "Point" + |}""".stripMargin)) + assert(row.getAs[Geometry](7) == Geometry.fromGeoJSON("""{ + | "coordinates": [ + | 2.000000000000000e+01, + | 4.000000000000000e+01 + | ], + | "type": "Point" + |}""".stripMargin)) + assert(row.getAs[Int](8) == Int.MinValue) + assert(row.getAs[Long](9) == Long.MinValue) + assert(row.getAs[Short](10) == Short.MinValue) + assert(row.getAs[String](11) == "string") + assert(row.getAs[Time](12) == Time.valueOf("16:23:04")) + assert(row.getAs[Timestamp](13) == new Timestamp(milliseconds)) + assert(row.getAs[Variant](14) == new Variant(1)) + assertThrows[ClassCastException](row.getAs[Boolean](0)) + assertThrows[ArrayIndexOutOfBoundsException](row.getAs[Boolean](-1)) + + data = Seq( + Row( + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null)) + + df = session.createDataFrame(data, schema) + row = df.collect()(0) + + assert(row.getAs[Array[Byte]](0) == null) + assert(!row.getAs[Boolean](1)) + assert(row.getAs[Byte](2) == 0) + assert(row.getAs[Date](3) == null) + assert(row.getAs[Double](4) == 0) + assert(row.getAs[Float](5) == 0) + assert(row.getAs[Geography](6) == null) + assert(row.getAs[Geometry](7) == null) + assert(row.getAs[Int](8) == 0) + assert(row.getAs[Long](9) == 0) + assert(row.getAs[Short](10) == 0) + assert(row.getAs[String](11) == null) + assert(row.getAs[Time](12) == null) + assert(row.getAs[Timestamp](13) == null) + assert(row.getAs[Variant](14) == null) + } + + test("getAs with structured map") { + structuredTypeTest { + val query = + """SELECT + | {'a':1,'b':2}::MAP(VARCHAR, NUMBER) as map1, + | {'1':'a','2':'b'}::MAP(NUMBER, VARCHAR) as map2, + | {'1':{'a':1,'b':2},'2':{'c':3}}::MAP(NUMBER, MAP(VARCHAR, NUMBER)) as map3 + |""".stripMargin + + val df = session.sql(query) + val row = df.collect()(0) + + val map1 = row.getAs[Map[String, Long]](0) + assert(map1("a") == 1L) + assert(map1("b") == 2L) + + val map2 = row.getAs[Map[Long, String]](1) + assert(map2(1) == "a") + assert(map2(2) == "b") + + val map3 = row.getAs[Map[Long, Map[String, Long]]](2) + assert(map3(1) == Map("a" -> 1, "b" -> 2)) + assert(map3(2) == Map("c" -> 3)) + } + } + + test("getAs with structured array") { + structuredTypeTest { + val oldTimeZone = TimeZone.getDefault + try { + TimeZone.setDefault(TimeZone.getTimeZone("US/Pacific")) + + val query = + """SELECT + | [1,2,3]::ARRAY(NUMBER) AS arr1, + | ['a','b']::ARRAY(VARCHAR) AS arr2, + | [parse_json(31000000)::timestamp_ntz]::ARRAY(TIMESTAMP_NTZ) AS arr3, + | [[1,2]]::ARRAY(ARRAY) AS arr4 + |""".stripMargin + + val df = session.sql(query) + val row = df.collect()(0) + + val array1 = row.getAs[Array[Object]](0) + assert(array1 sameElements Array(1, 2, 3)) + + val array2 = row.getAs[Array[Object]](1) + assert(array2 sameElements Array("a", "b")) + + val array3 = row.getAs[Array[Object]](2) + assert(array3 sameElements Array(new Timestamp(31000000000L))) + + val array4 = row.getAs[Array[Object]](3) + assert(array4 sameElements Array("[\n 1,\n 2\n]")) + } finally { + TimeZone.setDefault(oldTimeZone) + } + } + } + test("hashCode") { val row1 = Row(1, 2, 3) val row2 = Row("str", null, 3)