diff --git a/src/main/java/com/snowflake/snowpark_java/Row.java b/src/main/java/com/snowflake/snowpark_java/Row.java index 0921a0d6..40959927 100644 --- a/src/main/java/com/snowflake/snowpark_java/Row.java +++ b/src/main/java/com/snowflake/snowpark_java/Row.java @@ -425,6 +425,64 @@ public Row getObject(int index) { return (Row) get(index); } + /** + * Returns the value at the specified column index and casts it to the desired type {@code T}. + * + *

Example: + * + *

{@code
+   * Row row = Row.create(1, "Alice", 95.5);
+   * row.getAs(0, Integer.class); // Returns 1 as an Integer
+   * row.getAs(1, String.class); // Returns "Alice" as a String
+   * row.getAs(2, Double.class); // Returns 95.5 as a Double
+   * }
+ * + * @param index the zero-based column index within the row. + * @param clazz the {@code Class} object representing the type {@code T}. + * @param the expected type of the value at the specified column index. + * @return the value at the specified column index cast to type {@code T}. + * @throws ClassCastException if the value at the given index cannot be cast to type {@code T}. + * @throws ArrayIndexOutOfBoundsException if the column index is out of bounds. + * @since 1.14.0 + */ + @SuppressWarnings("unchecked") + public T getAs(int index, Class clazz) + throws ClassCastException, ArrayIndexOutOfBoundsException { + if (isNullAt(index)) { + return (T) get(index); + } + + if (clazz == Byte.class) { + return (T) (Object) getByte(index); + } + + if (clazz == Double.class) { + return (T) (Object) getDouble(index); + } + + if (clazz == Float.class) { + return (T) (Object) getFloat(index); + } + + if (clazz == Integer.class) { + return (T) (Object) getInt(index); + } + + if (clazz == Long.class) { + return (T) (Object) getLong(index); + } + + if (clazz == Short.class) { + return (T) (Object) getShort(index); + } + + if (clazz == Variant.class) { + return (T) getVariant(index); + } + + return (T) get(index); + } + /** * Generates a string value to represent the content of this row. * diff --git a/src/main/scala/com/snowflake/snowpark/Row.scala b/src/main/scala/com/snowflake/snowpark/Row.scala index a1dc5aef..64a5db36 100644 --- a/src/main/scala/com/snowflake/snowpark/Row.scala +++ b/src/main/scala/com/snowflake/snowpark/Row.scala @@ -367,6 +367,39 @@ class Row protected (values: Array[Any]) extends Serializable { getAs[Map[T, U]](index) } + /** + * Returns the value at the specified column index and casts it to the desired type `T`. 
+ * + * Example: + * {{{ + * val row = Row(1, "Alice", 95.5) + * row.getAs[Int](0) // Returns 1 as an Int + * row.getAs[String](1) // Returns "Alice" as a String + * row.getAs[Double](2) // Returns 95.5 as a Double + * }}} + * + * @param index the zero-based column index within the row. + * @tparam T the expected type of the value at the specified column index. + * @return the value at the specified column index cast to type `T`. + * @throws ClassCastException if the value at the given index cannot be cast to type `T`. + * @throws ArrayIndexOutOfBoundsException if the column index is out of bounds. + * @group getter + * @since 1.14.0 + */ + def getAs[T](index: Int)(implicit classTag: ClassTag[T]): T = { + classTag.runtimeClass match { + case _ if isNullAt(index) => get(index).asInstanceOf[T] + case c if c == classOf[Byte] => getByte(index).asInstanceOf[T] + case c if c == classOf[Double] => getDouble(index).asInstanceOf[T] + case c if c == classOf[Float] => getFloat(index).asInstanceOf[T] + case c if c == classOf[Int] => getInt(index).asInstanceOf[T] + case c if c == classOf[Long] => getLong(index).asInstanceOf[T] + case c if c == classOf[Short] => getShort(index).asInstanceOf[T] + case c if c == classOf[Variant] => getVariant(index).asInstanceOf[T] + case _ => get(index).asInstanceOf[T] + } + } + protected def convertValueToString(value: Any): String = value match { case null => "null" @@ -400,10 +433,7 @@ class Row protected (values: Array[Any]) extends Serializable { .map(convertValueToString) .mkString("Row[", ",", "]") - private def getAs[T](index: Int): T = get(index).asInstanceOf[T] - - private def getAnyValAs[T <: AnyVal](index: Int): T = + private def getAnyValAs[T <: AnyVal](index: Int)(implicit classTag: ClassTag[T]): T = if (isNullAt(index)) throw new NullPointerException(s"Value at index $index is null") - else getAs[T](index) - + else getAs[T](index)(classTag) } diff --git a/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java 
b/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java index 0570c421..df15f438 100644 --- a/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java +++ b/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java @@ -1,5 +1,7 @@ package com.snowflake.snowpark_test; +import static org.junit.Assert.assertThrows; + import com.snowflake.snowpark_java.DataFrame; import com.snowflake.snowpark_java.Row; import com.snowflake.snowpark_java.types.*; @@ -429,4 +431,106 @@ public void testGetRow() { }, getSession()); } + + @Test + public void getAs() { + long milliseconds = System.currentTimeMillis(); + + StructType schema = + StructType.create( + new StructField("c01", DataTypes.BinaryType), + new StructField("c02", DataTypes.BooleanType), + new StructField("c03", DataTypes.ByteType), + new StructField("c04", DataTypes.DateType), + new StructField("c05", DataTypes.DoubleType), + new StructField("c06", DataTypes.FloatType), + new StructField("c07", DataTypes.GeographyType), + new StructField("c08", DataTypes.GeometryType), + new StructField("c09", DataTypes.IntegerType), + new StructField("c10", DataTypes.LongType), + new StructField("c11", DataTypes.ShortType), + new StructField("c12", DataTypes.StringType), + new StructField("c13", DataTypes.TimeType), + new StructField("c14", DataTypes.TimestampType), + new StructField("c15", DataTypes.VariantType)); + + Row[] data = { + Row.create( + new byte[] {1, 2}, + true, + Byte.MIN_VALUE, + Date.valueOf("2024-01-01"), + Double.MIN_VALUE, + Float.MIN_VALUE, + Geography.fromGeoJSON("POINT(30 10)"), + Geometry.fromGeoJSON("POINT(20 40)"), + Integer.MIN_VALUE, + Long.MIN_VALUE, + Short.MIN_VALUE, + "string", + Time.valueOf("16:23:04"), + new Timestamp(milliseconds), + new Variant(1)) + }; + + DataFrame df = getSession().createDataFrame(data, schema); + Row row = df.collect()[0]; + + assert Arrays.equals(row.getAs(0, byte[].class), new byte[] {1, 2}); + assert row.getAs(1, Boolean.class); + assert row.getAs(2, Byte.class) == 
Byte.MIN_VALUE; + assert row.getAs(3, Date.class).equals(Date.valueOf("2024-01-01")); + assert row.getAs(4, Double.class) == Double.MIN_VALUE; + assert row.getAs(5, Float.class) == Float.MIN_VALUE; + assert row.getAs(6, Geography.class) + .equals( + Geography.fromGeoJSON( + "{\n \"coordinates\": [\n 30,\n 10\n ],\n \"type\": \"Point\"\n}")); + assert row.getAs(7, Geometry.class) + .equals( + Geometry.fromGeoJSON( + "{\n \"coordinates\": [\n 2.000000000000000e+01,\n 4.000000000000000e+01\n ],\n \"type\": \"Point\"\n}")); + assert row.getAs(8, Integer.class) == Integer.MIN_VALUE; + assert row.getAs(9, Long.class) == Long.MIN_VALUE; + assert row.getAs(10, Short.class) == Short.MIN_VALUE; + assert row.getAs(11, String.class).equals("string"); + assert row.getAs(12, Time.class).equals(Time.valueOf("16:23:04")); + assert row.getAs(13, Timestamp.class).equals(new Timestamp(milliseconds)); + assert row.getAs(14, Variant.class).equals(new Variant(1)); + + Row finalRow = row; + assertThrows( + ClassCastException.class, + () -> { + Boolean b = finalRow.getAs(0, Boolean.class); + }); + assertThrows( + ArrayIndexOutOfBoundsException.class, () -> finalRow.getAs(-1, Boolean.class)); + + data = + new Row[] { + Row.create( + null, null, null, null, null, null, null, null, null, null, null, null, null, null, + null) + }; + + df = getSession().createDataFrame(data, schema); + row = df.collect()[0]; + + assert row.getAs(0, byte[].class) == null; + assert row.getAs(1, Boolean.class) == null; + assert row.getAs(2, Byte.class) == null; + assert row.getAs(3, Date.class) == null; + assert row.getAs(4, Double.class) == null; + assert row.getAs(5, Float.class) == null; + assert row.getAs(6, Geography.class) == null; + assert row.getAs(7, Geometry.class) == null; + assert row.getAs(8, Integer.class) == null; + assert row.getAs(9, Long.class) == null; + assert row.getAs(10, Short.class) == null; + assert row.getAs(11, String.class) == null; + assert row.getAs(12, Time.class) == null; + assert 
row.getAs(13, Timestamp.class) == null; + assert row.getAs(14, Variant.class) == null; + } } diff --git a/src/test/scala/com/snowflake/snowpark_test/RowSuite.scala b/src/test/scala/com/snowflake/snowpark_test/RowSuite.scala index 54aba687..5491f6f1 100644 --- a/src/test/scala/com/snowflake/snowpark_test/RowSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/RowSuite.scala @@ -1,9 +1,9 @@ package com.snowflake.snowpark_test -import com.snowflake.snowpark.types.{Geography, Geometry, Variant} +import com.snowflake.snowpark.types._ import com.snowflake.snowpark.{Row, SNTestBase, SnowparkClientException} -import java.sql.{Date, Timestamp} +import java.sql.{Date, Time, Timestamp} import java.time.{Instant, LocalDate} import java.util @@ -233,6 +233,116 @@ class RowSuite extends SNTestBase { assert(err.message.matches(msg)) } + test("getAs") { + val milliseconds = System.currentTimeMillis() + + val schema = StructType( + Seq( + StructField("c01", BinaryType), + StructField("c02", BooleanType), + StructField("c03", ByteType), + StructField("c04", DateType), + StructField("c05", DoubleType), + StructField("c06", FloatType), + StructField("c07", GeographyType), + StructField("c08", GeometryType), + StructField("c09", IntegerType), + StructField("c10", LongType), + StructField("c11", ShortType), + StructField("c12", StringType), + StructField("c13", TimeType), + StructField("c14", TimestampType), + StructField("c15", VariantType))) + + var data = Seq( + Row( + Array[Byte](1, 9), + true, + Byte.MinValue, + Date.valueOf("2024-01-01"), + Double.MinValue, + Float.MinPositiveValue, + Geography.fromGeoJSON("POINT(30 10)"), + Geometry.fromGeoJSON("POINT(20 40)"), + Int.MinValue, + Long.MinValue, + Short.MinValue, + "string", + Time.valueOf("16:23:04"), + new Timestamp(milliseconds), + new Variant(1))) + + var df = session.createDataFrame(data, schema) + var row = df.collect()(0) + + assert(row.getAs[Array[Byte]](0) sameElements Array[Byte](1, 9)) + 
assert(row.getAs[Boolean](1)) + assert(row.getAs[Byte](2) == Byte.MinValue) + assert(row.getAs[Date](3) == Date.valueOf("2024-01-01")) + assert(row.getAs[Double](4) == Double.MinValue) + assert(row.getAs[Float](5) == Float.MinPositiveValue) + assert(row.getAs[Geography](6) == Geography.fromGeoJSON("""{ + | "coordinates": [ + | 30, + | 10 + | ], + | "type": "Point" + |}""".stripMargin)) + assert(row.getAs[Geometry](7) == Geometry.fromGeoJSON("""{ + | "coordinates": [ + | 2.000000000000000e+01, + | 4.000000000000000e+01 + | ], + | "type": "Point" + |}""".stripMargin)) + assert(row.getAs[Int](8) == Int.MinValue) + assert(row.getAs[Long](9) == Long.MinValue) + assert(row.getAs[Short](10) == Short.MinValue) + assert(row.getAs[String](11) == "string") + assert(row.getAs[Time](12) == Time.valueOf("16:23:04")) + assert(row.getAs[Timestamp](13) == new Timestamp(milliseconds)) + assert(row.getAs[Variant](14) == new Variant(1)) + assertThrows[ClassCastException](row.getAs[Boolean](0)) + assertThrows[ArrayIndexOutOfBoundsException](row.getAs[Boolean](-1)) + + data = Seq( + Row( + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null)) + + df = session.createDataFrame(data, schema) + row = df.collect()(0) + + assert(row.getAs[Array[Byte]](0) == null) + assert(!row.getAs[Boolean](1)) + assert(row.getAs[Byte](2) == 0) + assert(row.getAs[Date](3) == null) + assert(row.getAs[Double](4) == 0) + assert(row.getAs[Float](5) == 0) + assert(row.getAs[Geography](6) == null) + assert(row.getAs[Geometry](7) == null) + assert(row.getAs[Int](8) == 0) + assert(row.getAs[Long](9) == 0) + assert(row.getAs[Short](10) == 0) + assert(row.getAs[String](11) == null) + assert(row.getAs[Time](12) == null) + assert(row.getAs[Timestamp](13) == null) + assert(row.getAs[Variant](14) == null) + } + test("hashCode") { val row1 = Row(1, 2, 3) val row2 = Row("str", null, 3)