diff --git a/src/main/java/com/snowflake/snowpark_java/Row.java b/src/main/java/com/snowflake/snowpark_java/Row.java
index 0921a0d6..74475e2d 100644
--- a/src/main/java/com/snowflake/snowpark_java/Row.java
+++ b/src/main/java/com/snowflake/snowpark_java/Row.java
@@ -425,6 +425,64 @@ public Row getObject(int index) {
return (Row) get(index);
}
+ /**
+ * Returns the value at the specified column index and casts it to the desired type {@code T}.
+ *
+ *
Example:
+ *
+ *
{@code
+ * Row row = Row.create(1, "Alice", 95.5);
+ * row.getAs(0, Integer.class); // Returns 1 as an Int
+ * row.getAs(1, String.class); // Returns "Alice" as a String
+ * row.getAs(2, Double.class); // Returns 95.5 as a Double
+ * }
+ *
+ * @param index the zero-based column index within the row.
+ * @param clazz the {@code Class} object representing the type {@code T}.
+ * @param the expected type of the value at the specified column index.
+ * @return the value at the specified column index cast to type {@code T}.
+ * @throws ClassCastException if the value at the given index cannot be cast to type {@code T}.
+ * @throws ArrayIndexOutOfBoundsException if the column index is out of bounds.
+ * @since 1.15.0
+ */
+ @SuppressWarnings("unchecked")
+ public T getAs(int index, Class clazz)
+ throws ClassCastException, ArrayIndexOutOfBoundsException {
+ if (isNullAt(index)) {
+ return (T) get(index);
+ }
+
+ if (clazz == Byte.class) {
+ return (T) (Object) getByte(index);
+ }
+
+ if (clazz == Double.class) {
+ return (T) (Object) getDouble(index);
+ }
+
+ if (clazz == Float.class) {
+ return (T) (Object) getFloat(index);
+ }
+
+ if (clazz == Integer.class) {
+ return (T) (Object) getInt(index);
+ }
+
+ if (clazz == Long.class) {
+ return (T) (Object) getLong(index);
+ }
+
+ if (clazz == Short.class) {
+ return (T) (Object) getShort(index);
+ }
+
+ if (clazz == Variant.class) {
+ return (T) getVariant(index);
+ }
+
+ return (T) get(index);
+ }
+
/**
* Generates a string value to represent the content of this row.
*
diff --git a/src/main/scala/com/snowflake/snowpark/Row.scala b/src/main/scala/com/snowflake/snowpark/Row.scala
index a1dc5aef..34eb9bbf 100644
--- a/src/main/scala/com/snowflake/snowpark/Row.scala
+++ b/src/main/scala/com/snowflake/snowpark/Row.scala
@@ -367,6 +367,39 @@ class Row protected (values: Array[Any]) extends Serializable {
getAs[Map[T, U]](index)
}
+ /**
+ * Returns the value at the specified column index and casts it to the desired type `T`.
+ *
+ * Example:
+ * {{{
+ * val row = Row(1, "Alice", 95.5)
+ * row.getAs[Int](0) // Returns 1 as an Int
+ * row.getAs[String](1) // Returns "Alice" as a String
+ * row.getAs[Double](2) // Returns 95.5 as a Double
+ * }}}
+ *
+ * @param index the zero-based column index within the row.
+ * @tparam T the expected type of the value at the specified column index.
+ * @return the value at the specified column index cast to type `T`.
+ * @throws ClassCastException if the value at the given index cannot be cast to type `T`.
+ * @throws ArrayIndexOutOfBoundsException if the column index is out of bounds.
+ * @group getter
+ * @since 1.15.0
+ */
+ def getAs[T](index: Int)(implicit classTag: ClassTag[T]): T = {
+ classTag.runtimeClass match {
+ case _ if isNullAt(index) => get(index).asInstanceOf[T]
+ case c if c == classOf[Byte] => getByte(index).asInstanceOf[T]
+ case c if c == classOf[Double] => getDouble(index).asInstanceOf[T]
+ case c if c == classOf[Float] => getFloat(index).asInstanceOf[T]
+ case c if c == classOf[Int] => getInt(index).asInstanceOf[T]
+ case c if c == classOf[Long] => getLong(index).asInstanceOf[T]
+ case c if c == classOf[Short] => getShort(index).asInstanceOf[T]
+ case c if c == classOf[Variant] => getVariant(index).asInstanceOf[T]
+ case _ => get(index).asInstanceOf[T]
+ }
+ }
+
protected def convertValueToString(value: Any): String =
value match {
case null => "null"
@@ -400,10 +433,7 @@ class Row protected (values: Array[Any]) extends Serializable {
.map(convertValueToString)
.mkString("Row[", ",", "]")
- private def getAs[T](index: Int): T = get(index).asInstanceOf[T]
-
- private def getAnyValAs[T <: AnyVal](index: Int): T =
+ private def getAnyValAs[T <: AnyVal](index: Int)(implicit classTag: ClassTag[T]): T =
if (isNullAt(index)) throw new NullPointerException(s"Value at index $index is null")
- else getAs[T](index)
-
+ else getAs[T](index)(classTag)
}
diff --git a/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java
index 0570c421..f8918292 100644
--- a/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java
+++ b/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java
@@ -1,5 +1,7 @@
package com.snowflake.snowpark_test;
+import static org.junit.Assert.assertThrows;
+
import com.snowflake.snowpark_java.DataFrame;
import com.snowflake.snowpark_java.Row;
import com.snowflake.snowpark_java.types.*;
@@ -7,10 +9,13 @@
import java.sql.Date;
import java.sql.Time;
import java.sql.Timestamp;
+import java.util.ArrayList;
import java.util.Arrays;
+import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
+import java.util.TimeZone;
import org.junit.Test;
public class JavaRowSuite extends TestBase {
@@ -429,4 +434,172 @@ public void testGetRow() {
},
getSession());
}
+
+ @Test
+ public void getAs() {
+ long milliseconds = System.currentTimeMillis();
+
+ StructType schema =
+ StructType.create(
+ new StructField("c01", DataTypes.BinaryType),
+ new StructField("c02", DataTypes.BooleanType),
+ new StructField("c03", DataTypes.ByteType),
+ new StructField("c04", DataTypes.DateType),
+ new StructField("c05", DataTypes.DoubleType),
+ new StructField("c06", DataTypes.FloatType),
+ new StructField("c07", DataTypes.GeographyType),
+ new StructField("c08", DataTypes.GeometryType),
+ new StructField("c09", DataTypes.IntegerType),
+ new StructField("c10", DataTypes.LongType),
+ new StructField("c11", DataTypes.ShortType),
+ new StructField("c12", DataTypes.StringType),
+ new StructField("c13", DataTypes.TimeType),
+ new StructField("c14", DataTypes.TimestampType),
+ new StructField("c15", DataTypes.VariantType));
+
+ Row[] data = {
+ Row.create(
+ new byte[] {1, 2},
+ true,
+ Byte.MIN_VALUE,
+ Date.valueOf("2024-01-01"),
+ Double.MIN_VALUE,
+ Float.MIN_VALUE,
+ Geography.fromGeoJSON("POINT(30 10)"),
+ Geometry.fromGeoJSON("POINT(20 40)"),
+ Integer.MIN_VALUE,
+ Long.MIN_VALUE,
+ Short.MIN_VALUE,
+ "string",
+ Time.valueOf("16:23:04"),
+ new Timestamp(milliseconds),
+ new Variant(1))
+ };
+
+ DataFrame df = getSession().createDataFrame(data, schema);
+ Row row = df.collect()[0];
+
+ assert Arrays.equals(row.getAs(0, byte[].class), new byte[] {1, 2});
+ assert row.getAs(1, Boolean.class);
+ assert row.getAs(2, Byte.class) == Byte.MIN_VALUE;
+ assert row.getAs(3, Date.class).equals(Date.valueOf("2024-01-01"));
+ assert row.getAs(4, Double.class) == Double.MIN_VALUE;
+ assert row.getAs(5, Float.class) == Float.MIN_VALUE;
+ assert row.getAs(6, Geography.class)
+ .equals(
+ Geography.fromGeoJSON(
+ "{\n \"coordinates\": [\n 30,\n 10\n ],\n \"type\": \"Point\"\n}"));
+ assert row.getAs(7, Geometry.class)
+ .equals(
+ Geometry.fromGeoJSON(
+ "{\n \"coordinates\": [\n 2.000000000000000e+01,\n 4.000000000000000e+01\n ],\n \"type\": \"Point\"\n}"));
+ assert row.getAs(8, Integer.class) == Integer.MIN_VALUE;
+ assert row.getAs(9, Long.class) == Long.MIN_VALUE;
+ assert row.getAs(10, Short.class) == Short.MIN_VALUE;
+ assert row.getAs(11, String.class).equals("string");
+ assert row.getAs(12, Time.class).equals(Time.valueOf("16:23:04"));
+ assert row.getAs(13, Timestamp.class).equals(new Timestamp(milliseconds));
+ assert row.getAs(14, Variant.class).equals(new Variant(1));
+
+ Row finalRow = row;
+ assertThrows(
+ ClassCastException.class,
+ () -> {
+ Boolean b = finalRow.getAs(0, Boolean.class);
+ });
+ assertThrows(ArrayIndexOutOfBoundsException.class, () -> finalRow.getAs(-1, Boolean.class));
+
+ data =
+ new Row[] {
+ Row.create(
+ null, null, null, null, null, null, null, null, null, null, null, null, null, null,
+ null)
+ };
+
+ df = getSession().createDataFrame(data, schema);
+ row = df.collect()[0];
+
+ assert row.getAs(0, byte[].class) == null;
+ assert row.getAs(1, Boolean.class) == null;
+ assert row.getAs(2, Byte.class) == null;
+ assert row.getAs(3, Date.class) == null;
+ assert row.getAs(4, Double.class) == null;
+ assert row.getAs(5, Float.class) == null;
+ assert row.getAs(6, Geography.class) == null;
+ assert row.getAs(7, Geometry.class) == null;
+ assert row.getAs(8, Integer.class) == null;
+ assert row.getAs(9, Long.class) == null;
+ assert row.getAs(10, Short.class) == null;
+ assert row.getAs(11, String.class) == null;
+ assert row.getAs(12, Time.class) == null;
+ assert row.getAs(13, Timestamp.class) == null;
+ assert row.getAs(14, Variant.class) == null;
+ }
+
+ @Test
+ public void getAsWithStructuredMap() {
+ structuredTypeTest(
+ () -> {
+ String query =
+ "SELECT "
+ + "{'a':1,'b':2}::MAP(VARCHAR, NUMBER) as map1,"
+ + "{'1':'a','2':'b'}::MAP(NUMBER, VARCHAR) as map2,"
+ + "{'1':{'a':1,'b':2},'2':{'c':3}}::MAP(NUMBER, MAP(VARCHAR, NUMBER)) as map3";
+
+ DataFrame df = getSession().sql(query);
+ Row row = df.collect()[0];
+
+ Map, ?> map1 = row.getAs(0, Map.class);
+ assert (Long) map1.get("a") == 1L;
+ assert (Long) map1.get("b") == 2L;
+
+ Map, ?> map2 = row.getAs(1, Map.class);
+ assert map2.get(1L).equals("a");
+ assert map2.get(2L).equals("b");
+
+ Map, ?> map3 = row.getAs(2, Map.class);
+ Map map3ExpectedInnerMap = new HashMap<>();
+ map3ExpectedInnerMap.put("a", 1L);
+ map3ExpectedInnerMap.put("b", 2L);
+ assert map3.get(1L).equals(map3ExpectedInnerMap);
+ assert map3.get(2L).equals(Collections.singletonMap("c", 3L));
+ },
+ getSession());
+ }
+
+ @Test
+ public void getAsWithStructuredArray() {
+ structuredTypeTest(
+ () -> {
+ TimeZone oldTimeZone = TimeZone.getDefault();
+ try {
+ TimeZone.setDefault(TimeZone.getTimeZone("US/Pacific"));
+
+ String query =
+ "SELECT "
+ + "[1,2,3]::ARRAY(NUMBER) AS arr1,"
+ + "['a','b']::ARRAY(VARCHAR) AS arr2,"
+ + "[parse_json(31000000)::timestamp_ntz]::ARRAY(TIMESTAMP_NTZ) AS arr3,"
+ + "[[1,2]]::ARRAY(ARRAY) AS arr4";
+
+ DataFrame df = getSession().sql(query);
+ Row row = df.collect()[0];
+
+ ArrayList> array1 = row.getAs(0, ArrayList.class);
+ assert array1.equals(Arrays.asList(1L, 2L, 3L));
+
+ ArrayList> array2 = row.getAs(1, ArrayList.class);
+ assert array2.equals(Arrays.asList("a", "b"));
+
+ ArrayList> array3 = row.getAs(2, ArrayList.class);
+ assert array3.equals(Collections.singletonList(new Timestamp(31000000000L)));
+
+ ArrayList> array4 = row.getAs(3, ArrayList.class);
+ assert array4.equals(Collections.singletonList("[\n 1,\n 2\n]"));
+ } finally {
+ TimeZone.setDefault(oldTimeZone);
+ }
+ },
+ getSession());
+ }
}
diff --git a/src/test/scala/com/snowflake/snowpark_test/RowSuite.scala b/src/test/scala/com/snowflake/snowpark_test/RowSuite.scala
index 54aba687..df87666f 100644
--- a/src/test/scala/com/snowflake/snowpark_test/RowSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark_test/RowSuite.scala
@@ -1,11 +1,12 @@
package com.snowflake.snowpark_test
-import com.snowflake.snowpark.types.{Geography, Geometry, Variant}
+import com.snowflake.snowpark.types._
import com.snowflake.snowpark.{Row, SNTestBase, SnowparkClientException}
-import java.sql.{Date, Timestamp}
+import java.sql.{Date, Time, Timestamp}
import java.time.{Instant, LocalDate}
import java.util
+import java.util.TimeZone
class RowSuite extends SNTestBase {
@@ -233,6 +234,176 @@ class RowSuite extends SNTestBase {
assert(err.message.matches(msg))
}
+ test("getAs") {
+ val milliseconds = System.currentTimeMillis()
+
+ val schema = StructType(
+ Seq(
+ StructField("c01", BinaryType),
+ StructField("c02", BooleanType),
+ StructField("c03", ByteType),
+ StructField("c04", DateType),
+ StructField("c05", DoubleType),
+ StructField("c06", FloatType),
+ StructField("c07", GeographyType),
+ StructField("c08", GeometryType),
+ StructField("c09", IntegerType),
+ StructField("c10", LongType),
+ StructField("c11", ShortType),
+ StructField("c12", StringType),
+ StructField("c13", TimeType),
+ StructField("c14", TimestampType),
+ StructField("c15", VariantType)))
+
+ var data = Seq(
+ Row(
+ Array[Byte](1, 9),
+ true,
+ Byte.MinValue,
+ Date.valueOf("2024-01-01"),
+ Double.MinValue,
+ Float.MinPositiveValue,
+ Geography.fromGeoJSON("POINT(30 10)"),
+ Geometry.fromGeoJSON("POINT(20 40)"),
+ Int.MinValue,
+ Long.MinValue,
+ Short.MinValue,
+ "string",
+ Time.valueOf("16:23:04"),
+ new Timestamp(milliseconds),
+ new Variant(1)))
+
+ var df = session.createDataFrame(data, schema)
+ var row = df.collect()(0)
+
+ assert(row.getAs[Array[Byte]](0) sameElements Array[Byte](1, 9))
+ assert(row.getAs[Boolean](1))
+ assert(row.getAs[Byte](2) == Byte.MinValue)
+ assert(row.getAs[Date](3) == Date.valueOf("2024-01-01"))
+ assert(row.getAs[Double](4) == Double.MinValue)
+ assert(row.getAs[Float](5) == Float.MinPositiveValue)
+ assert(row.getAs[Geography](6) == Geography.fromGeoJSON("""{
+ | "coordinates": [
+ | 30,
+ | 10
+ | ],
+ | "type": "Point"
+ |}""".stripMargin))
+ assert(row.getAs[Geometry](7) == Geometry.fromGeoJSON("""{
+ | "coordinates": [
+ | 2.000000000000000e+01,
+ | 4.000000000000000e+01
+ | ],
+ | "type": "Point"
+ |}""".stripMargin))
+ assert(row.getAs[Int](8) == Int.MinValue)
+ assert(row.getAs[Long](9) == Long.MinValue)
+ assert(row.getAs[Short](10) == Short.MinValue)
+ assert(row.getAs[String](11) == "string")
+ assert(row.getAs[Time](12) == Time.valueOf("16:23:04"))
+ assert(row.getAs[Timestamp](13) == new Timestamp(milliseconds))
+ assert(row.getAs[Variant](14) == new Variant(1))
+ assertThrows[ClassCastException](row.getAs[Boolean](0))
+ assertThrows[ArrayIndexOutOfBoundsException](row.getAs[Boolean](-1))
+
+ data = Seq(
+ Row(
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null))
+
+ df = session.createDataFrame(data, schema)
+ row = df.collect()(0)
+
+ assert(row.getAs[Array[Byte]](0) == null)
+ assert(!row.getAs[Boolean](1))
+ assert(row.getAs[Byte](2) == 0)
+ assert(row.getAs[Date](3) == null)
+ assert(row.getAs[Double](4) == 0)
+ assert(row.getAs[Float](5) == 0)
+ assert(row.getAs[Geography](6) == null)
+ assert(row.getAs[Geometry](7) == null)
+ assert(row.getAs[Int](8) == 0)
+ assert(row.getAs[Long](9) == 0)
+ assert(row.getAs[Short](10) == 0)
+ assert(row.getAs[String](11) == null)
+ assert(row.getAs[Time](12) == null)
+ assert(row.getAs[Timestamp](13) == null)
+ assert(row.getAs[Variant](14) == null)
+ }
+
+ test("getAs with structured map") {
+ structuredTypeTest {
+ val query =
+ """SELECT
+ | {'a':1,'b':2}::MAP(VARCHAR, NUMBER) as map1,
+ | {'1':'a','2':'b'}::MAP(NUMBER, VARCHAR) as map2,
+ | {'1':{'a':1,'b':2},'2':{'c':3}}::MAP(NUMBER, MAP(VARCHAR, NUMBER)) as map3
+ |""".stripMargin
+
+ val df = session.sql(query)
+ val row = df.collect()(0)
+
+ val map1 = row.getAs[Map[String, Long]](0)
+ assert(map1("a") == 1L)
+ assert(map1("b") == 2L)
+
+ val map2 = row.getAs[Map[Long, String]](1)
+ assert(map2(1) == "a")
+ assert(map2(2) == "b")
+
+ val map3 = row.getAs[Map[Long, Map[String, Long]]](2)
+ assert(map3(1) == Map("a" -> 1, "b" -> 2))
+ assert(map3(2) == Map("c" -> 3))
+ }
+ }
+
+ test("getAs with structured array") {
+ structuredTypeTest {
+ val oldTimeZone = TimeZone.getDefault
+ try {
+ TimeZone.setDefault(TimeZone.getTimeZone("US/Pacific"))
+
+ val query =
+ """SELECT
+ | [1,2,3]::ARRAY(NUMBER) AS arr1,
+ | ['a','b']::ARRAY(VARCHAR) AS arr2,
+ | [parse_json(31000000)::timestamp_ntz]::ARRAY(TIMESTAMP_NTZ) AS arr3,
+ | [[1,2]]::ARRAY(ARRAY) AS arr4
+ |""".stripMargin
+
+ val df = session.sql(query)
+ val row = df.collect()(0)
+
+ val array1 = row.getAs[Array[Object]](0)
+ assert(array1 sameElements Array(1, 2, 3))
+
+ val array2 = row.getAs[Array[Object]](1)
+ assert(array2 sameElements Array("a", "b"))
+
+ val array3 = row.getAs[Array[Object]](2)
+ assert(array3 sameElements Array(new Timestamp(31000000000L)))
+
+ val array4 = row.getAs[Array[Object]](3)
+ assert(array4 sameElements Array("[\n 1,\n 2\n]"))
+ } finally {
+ TimeZone.setDefault(oldTimeZone)
+ }
+ }
+ }
+
test("hashCode") {
val row1 = Row(1, 2, 3)
val row2 = Row("str", null, 3)