Skip to content

Commit

Permalink
Add support for com.snowflake.snowpark.Row.getAs function
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-fgonzalezmendez committed Aug 30, 2024
1 parent d1058da commit f7c9f9e
Show file tree
Hide file tree
Showing 4 changed files with 309 additions and 7 deletions.
58 changes: 58 additions & 0 deletions src/main/java/com/snowflake/snowpark_java/Row.java
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,64 @@ public Row getObject(int index) {
return (Row) get(index);
}

/**
* Returns the value at the specified column index and casts it to the desired type {@code T}.
*
* <p>Example:
*
* <pre>{@code
* Row row = Row.create(1, "Alice", 95.5);
* row.getAs(0, Integer.class); // Returns 1 as an Int
* row.getAs(1, String.class); // Returns "Alice" as a String
* row.getAs(2, Double.class); // Returns 95.5 as a Double
* }</pre>
*
* @param index the zero-based column index within the row.
* @param clazz the {@code Class} object representing the type {@code T}.
* @param <T> the expected type of the value at the specified column index.
* @return the value at the specified column index cast to type {@code T}.
* @throws ClassCastException if the value at the given index cannot be cast to type {@code T}.
* @throws ArrayIndexOutOfBoundsException if the column index is out of bounds.
* @since 1.14.0
*/
@SuppressWarnings("unchecked")
public <T> T getAs(int index, Class<T> clazz)
throws ClassCastException, ArrayIndexOutOfBoundsException {
if (isNullAt(index)) {
return (T) get(index);
}

if (clazz == Byte.class) {
return (T) (Object) getByte(index);
}

if (clazz == Double.class) {
return (T) (Object) getDouble(index);
}

if (clazz == Float.class) {
return (T) (Object) getFloat(index);
}

if (clazz == Integer.class) {
return (T) (Object) getInt(index);
}

if (clazz == Long.class) {
return (T) (Object) getLong(index);
}

if (clazz == Short.class) {
return (T) (Object) getShort(index);
}

if (clazz == Variant.class) {
return (T) getVariant(index);
}

return (T) get(index);
}

/**
* Generates a string value to represent the content of this row.
*
Expand Down
40 changes: 35 additions & 5 deletions src/main/scala/com/snowflake/snowpark/Row.scala
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,39 @@ class Row protected (values: Array[Any]) extends Serializable {
getAs[Map[T, U]](index)
}

/**
* Returns the value at the specified column index and casts it to the desired type `T`.
*
* Example:
* {{{
* val row = Row(1, "Alice", 95.5)
* row.getAs[Int](0) // Returns 1 as an Int
* row.getAs[String](1) // Returns "Alice" as a String
* row.getAs[Double](2) // Returns 95.5 as a Double
* }}}
*
* @param index the zero-based column index within the row.
* @tparam T the expected type of the value at the specified column index.
* @return the value at the specified column index cast to type `T`.
* @throws ClassCastException if the value at the given index cannot be cast to type `T`.
* @throws ArrayIndexOutOfBoundsException if the column index is out of bounds.
* @group getter
* @since 1.14.0
*/
def getAs[T](index: Int)(implicit classTag: ClassTag[T]): T = {
classTag.runtimeClass match {
case _ if isNullAt(index) => get(index).asInstanceOf[T]
case c if c == classOf[Byte] => getByte(index).asInstanceOf[T]
case c if c == classOf[Double] => getDouble(index).asInstanceOf[T]
case c if c == classOf[Float] => getFloat(index).asInstanceOf[T]
case c if c == classOf[Int] => getInt(index).asInstanceOf[T]
case c if c == classOf[Long] => getLong(index).asInstanceOf[T]
case c if c == classOf[Short] => getShort(index).asInstanceOf[T]
case c if c == classOf[Variant] => getVariant(index).asInstanceOf[T]
case _ => get(index).asInstanceOf[T]
}
}

protected def convertValueToString(value: Any): String =
value match {
case null => "null"
Expand Down Expand Up @@ -400,10 +433,7 @@ class Row protected (values: Array[Any]) extends Serializable {
.map(convertValueToString)
.mkString("Row[", ",", "]")

private def getAs[T](index: Int): T = get(index).asInstanceOf[T]

private def getAnyValAs[T <: AnyVal](index: Int): T =
private def getAnyValAs[T <: AnyVal](index: Int)(implicit classTag: ClassTag[T]): T =
if (isNullAt(index)) throw new NullPointerException(s"Value at index $index is null")
else getAs[T](index)

else getAs[T](index)(classTag)
}
104 changes: 104 additions & 0 deletions src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package com.snowflake.snowpark_test;

import static org.junit.Assert.assertThrows;

import com.snowflake.snowpark_java.DataFrame;
import com.snowflake.snowpark_java.Row;
import com.snowflake.snowpark_java.types.*;
Expand Down Expand Up @@ -429,4 +431,106 @@ public void testGetRow() {
},
getSession());
}

@Test
public void getAs() {
long milliseconds = System.currentTimeMillis();

StructType schema =
StructType.create(
new StructField("c01", DataTypes.BinaryType),
new StructField("c02", DataTypes.BooleanType),
new StructField("c03", DataTypes.ByteType),
new StructField("c04", DataTypes.DateType),
new StructField("c05", DataTypes.DoubleType),
new StructField("c06", DataTypes.FloatType),
new StructField("c07", DataTypes.GeographyType),
new StructField("c08", DataTypes.GeometryType),
new StructField("c09", DataTypes.IntegerType),
new StructField("c10", DataTypes.LongType),
new StructField("c11", DataTypes.ShortType),
new StructField("c12", DataTypes.StringType),
new StructField("c13", DataTypes.TimeType),
new StructField("c14", DataTypes.TimestampType),
new StructField("c15", DataTypes.VariantType));

Row[] data = {
Row.create(
new byte[] {1, 2},
true,
Byte.MIN_VALUE,
Date.valueOf("2024-01-01"),
Double.MIN_VALUE,
Float.MIN_VALUE,
Geography.fromGeoJSON("POINT(30 10)"),
Geometry.fromGeoJSON("POINT(20 40)"),
Integer.MIN_VALUE,
Long.MIN_VALUE,
Short.MIN_VALUE,
"string",
Time.valueOf("16:23:04"),
new Timestamp(milliseconds),
new Variant(1))
};

DataFrame df = getSession().createDataFrame(data, schema);
Row row = df.collect()[0];

assert Arrays.equals(row.getAs(0, byte[].class), new byte[] {1, 2});
assert row.getAs(1, Boolean.class);
assert row.getAs(2, Byte.class) == Byte.MIN_VALUE;
assert row.getAs(3, Date.class).equals(Date.valueOf("2024-01-01"));
assert row.getAs(4, Double.class) == Double.MIN_VALUE;
assert row.getAs(5, Float.class) == Float.MIN_VALUE;
assert row.getAs(6, Geography.class)
.equals(
Geography.fromGeoJSON(
"{\n \"coordinates\": [\n 30,\n 10\n ],\n \"type\": \"Point\"\n}"));
assert row.getAs(7, Geometry.class)
.equals(
Geometry.fromGeoJSON(
"{\n \"coordinates\": [\n 2.000000000000000e+01,\n 4.000000000000000e+01\n ],\n \"type\": \"Point\"\n}"));
assert row.getAs(8, Integer.class) == Integer.MIN_VALUE;
assert row.getAs(9, Long.class) == Long.MIN_VALUE;
assert row.getAs(10, Short.class) == Short.MIN_VALUE;
assert row.getAs(11, String.class).equals("string");
assert row.getAs(12, Time.class).equals(Time.valueOf("16:23:04"));
assert row.getAs(13, Timestamp.class).equals(new Timestamp(milliseconds));
assert row.getAs(14, Variant.class).equals(new Variant(1));

Row finalRow = row;
assertThrows(
ClassCastException.class,
() -> {
Boolean b = finalRow.getAs(0, Boolean.class);
});
assertThrows(
ArrayIndexOutOfBoundsException.class, () -> finalRow.getAs(-1, Boolean.class));

data =
new Row[] {
Row.create(
null, null, null, null, null, null, null, null, null, null, null, null, null, null,
null)
};

df = getSession().createDataFrame(data, schema);
row = df.collect()[0];

assert row.getAs(0, byte[].class) == null;
assert row.getAs(1, Boolean.class) == null;
assert row.getAs(2, Byte.class) == null;
assert row.getAs(3, Date.class) == null;
assert row.getAs(4, Double.class) == null;
assert row.getAs(5, Float.class) == null;
assert row.getAs(6, Geography.class) == null;
assert row.getAs(7, Geometry.class) == null;
assert row.getAs(8, Integer.class) == null;
assert row.getAs(9, Long.class) == null;
assert row.getAs(10, Short.class) == null;
assert row.getAs(11, String.class) == null;
assert row.getAs(12, Time.class) == null;
assert row.getAs(13, Timestamp.class) == null;
assert row.getAs(14, Variant.class) == null;
}
}
114 changes: 112 additions & 2 deletions src/test/scala/com/snowflake/snowpark_test/RowSuite.scala
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package com.snowflake.snowpark_test

import com.snowflake.snowpark.types.{Geography, Geometry, Variant}
import com.snowflake.snowpark.types._
import com.snowflake.snowpark.{Row, SNTestBase, SnowparkClientException}

import java.sql.{Date, Timestamp}
import java.sql.{Date, Time, Timestamp}
import java.time.{Instant, LocalDate}
import java.util

Expand Down Expand Up @@ -233,6 +233,116 @@ class RowSuite extends SNTestBase {
assert(err.message.matches(msg))
}

test("getAs") {
val milliseconds = System.currentTimeMillis()

val schema = StructType(
Seq(
StructField("c01", BinaryType),
StructField("c02", BooleanType),
StructField("c03", ByteType),
StructField("c04", DateType),
StructField("c05", DoubleType),
StructField("c06", FloatType),
StructField("c07", GeographyType),
StructField("c08", GeometryType),
StructField("c09", IntegerType),
StructField("c10", LongType),
StructField("c11", ShortType),
StructField("c12", StringType),
StructField("c13", TimeType),
StructField("c14", TimestampType),
StructField("c15", VariantType)))

var data = Seq(
Row(
Array[Byte](1, 9),
true,
Byte.MinValue,
Date.valueOf("2024-01-01"),
Double.MinValue,
Float.MinPositiveValue,
Geography.fromGeoJSON("POINT(30 10)"),
Geometry.fromGeoJSON("POINT(20 40)"),
Int.MinValue,
Long.MinValue,
Short.MinValue,
"string",
Time.valueOf("16:23:04"),
new Timestamp(milliseconds),
new Variant(1)))

var df = session.createDataFrame(data, schema)
var row = df.collect()(0)

assert(row.getAs[Array[Byte]](0) sameElements Array[Byte](1, 9))
assert(row.getAs[Boolean](1))
assert(row.getAs[Byte](2) == Byte.MinValue)
assert(row.getAs[Date](3) == Date.valueOf("2024-01-01"))
assert(row.getAs[Double](4) == Double.MinValue)
assert(row.getAs[Float](5) == Float.MinPositiveValue)
assert(row.getAs[Geography](6) == Geography.fromGeoJSON("""{
| "coordinates": [
| 30,
| 10
| ],
| "type": "Point"
|}""".stripMargin))
assert(row.getAs[Geometry](7) == Geometry.fromGeoJSON("""{
| "coordinates": [
| 2.000000000000000e+01,
| 4.000000000000000e+01
| ],
| "type": "Point"
|}""".stripMargin))
assert(row.getAs[Int](8) == Int.MinValue)
assert(row.getAs[Long](9) == Long.MinValue)
assert(row.getAs[Short](10) == Short.MinValue)
assert(row.getAs[String](11) == "string")
assert(row.getAs[Time](12) == Time.valueOf("16:23:04"))
assert(row.getAs[Timestamp](13) == new Timestamp(milliseconds))
assert(row.getAs[Variant](14) == new Variant(1))
assertThrows[ClassCastException](row.getAs[Boolean](0))
assertThrows[ArrayIndexOutOfBoundsException](row.getAs[Boolean](-1))

data = Seq(
Row(
null,
null,
null,
null,
null,
null,
null,
null,
null,
null,
null,
null,
null,
null,
null))

df = session.createDataFrame(data, schema)
row = df.collect()(0)

assert(row.getAs[Array[Byte]](0) == null)
assert(!row.getAs[Boolean](1))
assert(row.getAs[Byte](2) == 0)
assert(row.getAs[Date](3) == null)
assert(row.getAs[Double](4) == 0)
assert(row.getAs[Float](5) == 0)
assert(row.getAs[Geography](6) == null)
assert(row.getAs[Geometry](7) == null)
assert(row.getAs[Int](8) == 0)
assert(row.getAs[Long](9) == 0)
assert(row.getAs[Short](10) == 0)
assert(row.getAs[String](11) == null)
assert(row.getAs[Time](12) == null)
assert(row.getAs[Timestamp](13) == null)
assert(row.getAs[Variant](14) == null)
}

test("hashCode") {
val row1 = Row(1, 2, 3)
val row2 = Row("str", null, 3)
Expand Down

0 comments on commit f7c9f9e

Please sign in to comment.