diff --git a/src/main/java/com/snowflake/snowpark_java/Row.java b/src/main/java/com/snowflake/snowpark_java/Row.java
index 0921a0d6..40959927 100644
--- a/src/main/java/com/snowflake/snowpark_java/Row.java
+++ b/src/main/java/com/snowflake/snowpark_java/Row.java
@@ -425,6 +425,64 @@ public Row getObject(int index) {
return (Row) get(index);
}
+ /**
+ * Returns the value at the specified column index and casts it to the desired type {@code T}.
+ *
+ *
Example:
+ *
+ *
{@code
+ * Row row = Row.create(1, "Alice", 95.5);
+ * row.getAs(0, Integer.class); // Returns 1 as an Int
+ * row.getAs(1, String.class); // Returns "Alice" as a String
+ * row.getAs(2, Double.class); // Returns 95.5 as a Double
+ * }
+ *
+ * @param index the zero-based column index within the row.
+ * @param clazz the {@code Class} object representing the type {@code T}.
+ * @param the expected type of the value at the specified column index.
+ * @return the value at the specified column index cast to type {@code T}.
+ * @throws ClassCastException if the value at the given index cannot be cast to type {@code T}.
+ * @throws ArrayIndexOutOfBoundsException if the column index is out of bounds.
+ * @since 1.14.0
+ */
+ @SuppressWarnings("unchecked")
+ public T getAs(int index, Class clazz)
+ throws ClassCastException, ArrayIndexOutOfBoundsException {
+ if (isNullAt(index)) {
+ return (T) get(index);
+ }
+
+ if (clazz == Byte.class) {
+ return (T) (Object) getByte(index);
+ }
+
+ if (clazz == Double.class) {
+ return (T) (Object) getDouble(index);
+ }
+
+ if (clazz == Float.class) {
+ return (T) (Object) getFloat(index);
+ }
+
+ if (clazz == Integer.class) {
+ return (T) (Object) getInt(index);
+ }
+
+ if (clazz == Long.class) {
+ return (T) (Object) getLong(index);
+ }
+
+ if (clazz == Short.class) {
+ return (T) (Object) getShort(index);
+ }
+
+ if (clazz == Variant.class) {
+ return (T) getVariant(index);
+ }
+
+ return (T) get(index);
+ }
+
/**
* Generates a string value to represent the content of this row.
*
diff --git a/src/main/scala/com/snowflake/snowpark/Row.scala b/src/main/scala/com/snowflake/snowpark/Row.scala
index a1dc5aef..64a5db36 100644
--- a/src/main/scala/com/snowflake/snowpark/Row.scala
+++ b/src/main/scala/com/snowflake/snowpark/Row.scala
@@ -367,6 +367,39 @@ class Row protected (values: Array[Any]) extends Serializable {
getAs[Map[T, U]](index)
}
+ /**
+ * Returns the value at the specified column index and casts it to the desired type `T`.
+ *
+ * Example:
+ * {{{
+ * val row = Row(1, "Alice", 95.5)
+ * row.getAs[Int](0) // Returns 1 as an Int
+ * row.getAs[String](1) // Returns "Alice" as a String
+ * row.getAs[Double](2) // Returns 95.5 as a Double
+ * }}}
+ *
+ * @param index the zero-based column index within the row.
+ * @tparam T the expected type of the value at the specified column index.
+ * @return the value at the specified column index cast to type `T`.
+ * @throws ClassCastException if the value at the given index cannot be cast to type `T`.
+ * @throws ArrayIndexOutOfBoundsException if the column index is out of bounds.
+ * @group getter
+ * @since 1.14.0
+ */
+ def getAs[T](index: Int)(implicit classTag: ClassTag[T]): T = {
+ classTag.runtimeClass match {
+ case _ if isNullAt(index) => get(index).asInstanceOf[T]
+ case c if c == classOf[Byte] => getByte(index).asInstanceOf[T]
+ case c if c == classOf[Double] => getDouble(index).asInstanceOf[T]
+ case c if c == classOf[Float] => getFloat(index).asInstanceOf[T]
+ case c if c == classOf[Int] => getInt(index).asInstanceOf[T]
+ case c if c == classOf[Long] => getLong(index).asInstanceOf[T]
+ case c if c == classOf[Short] => getShort(index).asInstanceOf[T]
+ case c if c == classOf[Variant] => getVariant(index).asInstanceOf[T]
+ case _ => get(index).asInstanceOf[T]
+ }
+ }
+
protected def convertValueToString(value: Any): String =
value match {
case null => "null"
@@ -400,10 +433,7 @@ class Row protected (values: Array[Any]) extends Serializable {
.map(convertValueToString)
.mkString("Row[", ",", "]")
- private def getAs[T](index: Int): T = get(index).asInstanceOf[T]
-
- private def getAnyValAs[T <: AnyVal](index: Int): T =
+ private def getAnyValAs[T <: AnyVal](index: Int)(implicit classTag: ClassTag[T]): T =
if (isNullAt(index)) throw new NullPointerException(s"Value at index $index is null")
- else getAs[T](index)
-
+ else getAs[T](index)(classTag)
}
diff --git a/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java
index 0570c421..df15f438 100644
--- a/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java
+++ b/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java
@@ -1,5 +1,7 @@
package com.snowflake.snowpark_test;
+import static org.junit.Assert.assertThrows;
+
import com.snowflake.snowpark_java.DataFrame;
import com.snowflake.snowpark_java.Row;
import com.snowflake.snowpark_java.types.*;
@@ -429,4 +431,106 @@ public void testGetRow() {
},
getSession());
}
+
+ @Test
+ public void getAs() {
+ long milliseconds = System.currentTimeMillis();
+
+ StructType schema =
+ StructType.create(
+ new StructField("c01", DataTypes.BinaryType),
+ new StructField("c02", DataTypes.BooleanType),
+ new StructField("c03", DataTypes.ByteType),
+ new StructField("c04", DataTypes.DateType),
+ new StructField("c05", DataTypes.DoubleType),
+ new StructField("c06", DataTypes.FloatType),
+ new StructField("c07", DataTypes.GeographyType),
+ new StructField("c08", DataTypes.GeometryType),
+ new StructField("c09", DataTypes.IntegerType),
+ new StructField("c10", DataTypes.LongType),
+ new StructField("c11", DataTypes.ShortType),
+ new StructField("c12", DataTypes.StringType),
+ new StructField("c13", DataTypes.TimeType),
+ new StructField("c14", DataTypes.TimestampType),
+ new StructField("c15", DataTypes.VariantType));
+
+ Row[] data = {
+ Row.create(
+ new byte[] {1, 2},
+ true,
+ Byte.MIN_VALUE,
+ Date.valueOf("2024-01-01"),
+ Double.MIN_VALUE,
+ Float.MIN_VALUE,
+ Geography.fromGeoJSON("POINT(30 10)"),
+ Geometry.fromGeoJSON("POINT(20 40)"),
+ Integer.MIN_VALUE,
+ Long.MIN_VALUE,
+ Short.MIN_VALUE,
+ "string",
+ Time.valueOf("16:23:04"),
+ new Timestamp(milliseconds),
+ new Variant(1))
+ };
+
+ DataFrame df = getSession().createDataFrame(data, schema);
+ Row row = df.collect()[0];
+
+ assert Arrays.equals(row.getAs(0, byte[].class), new byte[] {1, 2});
+ assert row.getAs(1, Boolean.class);
+ assert row.getAs(2, Byte.class) == Byte.MIN_VALUE;
+ assert row.getAs(3, Date.class).equals(Date.valueOf("2024-01-01"));
+ assert row.getAs(4, Double.class) == Double.MIN_VALUE;
+ assert row.getAs(5, Float.class) == Float.MIN_VALUE;
+ assert row.getAs(6, Geography.class)
+ .equals(
+ Geography.fromGeoJSON(
+ "{\n \"coordinates\": [\n 30,\n 10\n ],\n \"type\": \"Point\"\n}"));
+ assert row.getAs(7, Geometry.class)
+ .equals(
+ Geometry.fromGeoJSON(
+ "{\n \"coordinates\": [\n 2.000000000000000e+01,\n 4.000000000000000e+01\n ],\n \"type\": \"Point\"\n}"));
+ assert row.getAs(8, Integer.class) == Integer.MIN_VALUE;
+ assert row.getAs(9, Long.class) == Long.MIN_VALUE;
+ assert row.getAs(10, Short.class) == Short.MIN_VALUE;
+ assert row.getAs(11, String.class).equals("string");
+ assert row.getAs(12, Time.class).equals(Time.valueOf("16:23:04"));
+ assert row.getAs(13, Timestamp.class).equals(new Timestamp(milliseconds));
+ assert row.getAs(14, Variant.class).equals(new Variant(1));
+
+ Row finalRow = row;
+ assertThrows(
+ ClassCastException.class,
+ () -> {
+ Boolean b = finalRow.getAs(0, Boolean.class);
+ });
+ assertThrows(
+ ArrayIndexOutOfBoundsException.class, () -> finalRow.getAs(-1, Boolean.class));
+
+ data =
+ new Row[] {
+ Row.create(
+ null, null, null, null, null, null, null, null, null, null, null, null, null, null,
+ null)
+ };
+
+ df = getSession().createDataFrame(data, schema);
+ row = df.collect()[0];
+
+ assert row.getAs(0, byte[].class) == null;
+ assert row.getAs(1, Boolean.class) == null;
+ assert row.getAs(2, Byte.class) == null;
+ assert row.getAs(3, Date.class) == null;
+ assert row.getAs(4, Double.class) == null;
+ assert row.getAs(5, Float.class) == null;
+ assert row.getAs(6, Geography.class) == null;
+ assert row.getAs(7, Geometry.class) == null;
+ assert row.getAs(8, Integer.class) == null;
+ assert row.getAs(9, Long.class) == null;
+ assert row.getAs(10, Short.class) == null;
+ assert row.getAs(11, String.class) == null;
+ assert row.getAs(12, Time.class) == null;
+ assert row.getAs(13, Timestamp.class) == null;
+ assert row.getAs(14, Variant.class) == null;
+ }
}
diff --git a/src/test/scala/com/snowflake/snowpark_test/RowSuite.scala b/src/test/scala/com/snowflake/snowpark_test/RowSuite.scala
index 54aba687..5491f6f1 100644
--- a/src/test/scala/com/snowflake/snowpark_test/RowSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark_test/RowSuite.scala
@@ -1,9 +1,9 @@
package com.snowflake.snowpark_test
-import com.snowflake.snowpark.types.{Geography, Geometry, Variant}
+import com.snowflake.snowpark.types._
import com.snowflake.snowpark.{Row, SNTestBase, SnowparkClientException}
-import java.sql.{Date, Timestamp}
+import java.sql.{Date, Time, Timestamp}
import java.time.{Instant, LocalDate}
import java.util
@@ -233,6 +233,116 @@ class RowSuite extends SNTestBase {
assert(err.message.matches(msg))
}
+ test("getAs") {
+ val milliseconds = System.currentTimeMillis()
+
+ val schema = StructType(
+ Seq(
+ StructField("c01", BinaryType),
+ StructField("c02", BooleanType),
+ StructField("c03", ByteType),
+ StructField("c04", DateType),
+ StructField("c05", DoubleType),
+ StructField("c06", FloatType),
+ StructField("c07", GeographyType),
+ StructField("c08", GeometryType),
+ StructField("c09", IntegerType),
+ StructField("c10", LongType),
+ StructField("c11", ShortType),
+ StructField("c12", StringType),
+ StructField("c13", TimeType),
+ StructField("c14", TimestampType),
+ StructField("c15", VariantType)))
+
+ var data = Seq(
+ Row(
+ Array[Byte](1, 9),
+ true,
+ Byte.MinValue,
+ Date.valueOf("2024-01-01"),
+ Double.MinValue,
+ Float.MinPositiveValue,
+ Geography.fromGeoJSON("POINT(30 10)"),
+ Geometry.fromGeoJSON("POINT(20 40)"),
+ Int.MinValue,
+ Long.MinValue,
+ Short.MinValue,
+ "string",
+ Time.valueOf("16:23:04"),
+ new Timestamp(milliseconds),
+ new Variant(1)))
+
+ var df = session.createDataFrame(data, schema)
+ var row = df.collect()(0)
+
+ assert(row.getAs[Array[Byte]](0) sameElements Array[Byte](1, 9))
+ assert(row.getAs[Boolean](1))
+ assert(row.getAs[Byte](2) == Byte.MinValue)
+ assert(row.getAs[Date](3) == Date.valueOf("2024-01-01"))
+ assert(row.getAs[Double](4) == Double.MinValue)
+ assert(row.getAs[Float](5) == Float.MinPositiveValue)
+ assert(row.getAs[Geography](6) == Geography.fromGeoJSON("""{
+ | "coordinates": [
+ | 30,
+ | 10
+ | ],
+ | "type": "Point"
+ |}""".stripMargin))
+ assert(row.getAs[Geometry](7) == Geometry.fromGeoJSON("""{
+ | "coordinates": [
+ | 2.000000000000000e+01,
+ | 4.000000000000000e+01
+ | ],
+ | "type": "Point"
+ |}""".stripMargin))
+ assert(row.getAs[Int](8) == Int.MinValue)
+ assert(row.getAs[Long](9) == Long.MinValue)
+ assert(row.getAs[Short](10) == Short.MinValue)
+ assert(row.getAs[String](11) == "string")
+ assert(row.getAs[Time](12) == Time.valueOf("16:23:04"))
+ assert(row.getAs[Timestamp](13) == new Timestamp(milliseconds))
+ assert(row.getAs[Variant](14) == new Variant(1))
+ assertThrows[ClassCastException](row.getAs[Boolean](0))
+ assertThrows[ArrayIndexOutOfBoundsException](row.getAs[Boolean](-1))
+
+ data = Seq(
+ Row(
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null))
+
+ df = session.createDataFrame(data, schema)
+ row = df.collect()(0)
+
+ assert(row.getAs[Array[Byte]](0) == null)
+ assert(!row.getAs[Boolean](1))
+ assert(row.getAs[Byte](2) == 0)
+ assert(row.getAs[Date](3) == null)
+ assert(row.getAs[Double](4) == 0)
+ assert(row.getAs[Float](5) == 0)
+ assert(row.getAs[Geography](6) == null)
+ assert(row.getAs[Geometry](7) == null)
+ assert(row.getAs[Int](8) == 0)
+ assert(row.getAs[Long](9) == 0)
+ assert(row.getAs[Short](10) == 0)
+ assert(row.getAs[String](11) == null)
+ assert(row.getAs[Time](12) == null)
+ assert(row.getAs[Timestamp](13) == null)
+ assert(row.getAs[Variant](14) == null)
+ }
+
test("hashCode") {
val row1 = Row(1, 2, 3)
val row2 = Row("str", null, 3)