Skip to content

Commit

Permalink
SIT-2037 Add support for com.snowflake.snowpark.Row.getAs function (#…
Browse files Browse the repository at this point in the history
…148)

* Add support for com.snowflake.snowpark.Row.getAs function

* Fix code formatting

* Add test scenarios for structured maps and structured arrays

* Use collections compatible with Java 8

* Update tests for structured types

* Remove type inference for Java variables

* Update the since tag to 1.15.0
  • Loading branch information
sfc-gh-fgonzalezmendez authored Sep 10, 2024
1 parent 27d6ece commit eed60a4
Show file tree
Hide file tree
Showing 4 changed files with 439 additions and 7 deletions.
58 changes: 58 additions & 0 deletions src/main/java/com/snowflake/snowpark_java/Row.java
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,64 @@ public Row getObject(int index) {
return (Row) get(index);
}

/**
 * Returns the value at the specified column index, converted to the requested type {@code T}.
 *
 * <p>For {@code Byte}, {@code Double}, {@code Float}, {@code Integer}, {@code Long}, {@code Short}
 * and {@code Variant} the call is delegated to the matching typed getter ({@code getByte}, {@code
 * getDouble}, ...). Every other type is read with {@code get(int)} and checked against {@code
 * clazz} before it is returned, so a type mismatch is reported by this method itself rather than
 * later at the call site.
 *
 * <p>If {@code isNullAt(index)} is true, the result of {@code get(index)} is returned without any
 * type check.
 *
 * <p>Example:
 *
 * <pre>{@code
 * Row row = Row.create(1, "Alice", 95.5);
 * row.getAs(0, Integer.class); // Returns 1 as an Integer
 * row.getAs(1, String.class); // Returns "Alice" as a String
 * row.getAs(2, Double.class); // Returns 95.5 as a Double
 * }</pre>
 *
 * @param index the zero-based column index within the row.
 * @param clazz the {@code Class} object representing the type {@code T}.
 * @param <T> the expected type of the value at the specified column index.
 * @return the value at the specified column index cast to type {@code T}.
 * @throws ClassCastException if the value at the given index cannot be cast to type {@code T}.
 * @throws ArrayIndexOutOfBoundsException if the column index is out of bounds.
 * @since 1.15.0
 */
public <T> T getAs(int index, Class<T> clazz)
    throws ClassCastException, ArrayIndexOutOfBoundsException {
  // isNullAt is evaluated first so that an invalid index is reported before anything else.
  // Class.cast rejects every object for primitive tokens such as int.class, and a null token
  // cannot be checked at all; both keep the lenient unchecked cast so that callers relying on
  // it continue to work.
  if (isNullAt(index) || clazz == null || clazz.isPrimitive()) {
    @SuppressWarnings("unchecked")
    T unchecked = (T) get(index);
    return unchecked;
  }

  // In each branch below clazz is known to be the wrapper (or Variant) class, so Class.cast on
  // the boxed getter result always succeeds and replaces the former double unchecked cast.
  if (clazz == Byte.class) {
    return clazz.cast(getByte(index));
  }
  if (clazz == Double.class) {
    return clazz.cast(getDouble(index));
  }
  if (clazz == Float.class) {
    return clazz.cast(getFloat(index));
  }
  if (clazz == Integer.class) {
    return clazz.cast(getInt(index));
  }
  if (clazz == Long.class) {
    return clazz.cast(getLong(index));
  }
  if (clazz == Short.class) {
    return clazz.cast(getShort(index));
  }
  if (clazz == Variant.class) {
    return clazz.cast(getVariant(index));
  }

  // Checked cast: throws ClassCastException here instead of deferring the failure to whatever
  // cast the compiler inserts (or does not insert) at the call site.
  return clazz.cast(get(index));
}

/**
* Generates a string value to represent the content of this row.
*
Expand Down
40 changes: 35 additions & 5 deletions src/main/scala/com/snowflake/snowpark/Row.scala
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,39 @@ class Row protected (values: Array[Any]) extends Serializable {
getAs[Map[T, U]](index)
}

/**
 * Reads the value stored at the given column index and returns it as type `T`.
 *
 * When the value is null it is returned as-is. For `Byte`, `Double`, `Float`, `Int`,
 * `Long`, `Short` and `Variant` the matching typed getter is used; any other type is
 * read with `get` and cast.
 *
 * Example:
 * {{{
 *   val row = Row(1, "Alice", 95.5)
 *   row.getAs[Int](0) // Returns 1 as an Int
 *   row.getAs[String](1) // Returns "Alice" as a String
 *   row.getAs[Double](2) // Returns 95.5 as a Double
 * }}}
 *
 * @param index the zero-based column index within the row.
 * @tparam T the expected type of the value at the specified column index.
 * @return the value at the specified column index cast to type `T`.
 * @throws ClassCastException if the value at the given index cannot be cast to type `T`.
 * @throws ArrayIndexOutOfBoundsException if the column index is out of bounds.
 * @group getter
 * @since 1.15.0
 */
def getAs[T](index: Int)(implicit classTag: ClassTag[T]): T = {
  val requested = classTag.runtimeClass

  // The null check comes first so that null values bypass the typed getters.
  if (isNullAt(index)) get(index).asInstanceOf[T]
  else if (requested == classOf[Byte]) getByte(index).asInstanceOf[T]
  else if (requested == classOf[Double]) getDouble(index).asInstanceOf[T]
  else if (requested == classOf[Float]) getFloat(index).asInstanceOf[T]
  else if (requested == classOf[Int]) getInt(index).asInstanceOf[T]
  else if (requested == classOf[Long]) getLong(index).asInstanceOf[T]
  else if (requested == classOf[Short]) getShort(index).asInstanceOf[T]
  else if (requested == classOf[Variant]) getVariant(index).asInstanceOf[T]
  else get(index).asInstanceOf[T]
}

protected def convertValueToString(value: Any): String =
value match {
case null => "null"
Expand Down Expand Up @@ -400,10 +433,7 @@ class Row protected (values: Array[Any]) extends Serializable {
.map(convertValueToString)
.mkString("Row[", ",", "]")

private def getAs[T](index: Int): T = get(index).asInstanceOf[T]

private def getAnyValAs[T <: AnyVal](index: Int): T =
private def getAnyValAs[T <: AnyVal](index: Int)(implicit classTag: ClassTag[T]): T =
if (isNullAt(index)) throw new NullPointerException(s"Value at index $index is null")
else getAs[T](index)

else getAs[T](index)(classTag)
}
173 changes: 173 additions & 0 deletions src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
package com.snowflake.snowpark_test;

import static org.junit.Assert.assertThrows;

import com.snowflake.snowpark_java.DataFrame;
import com.snowflake.snowpark_java.Row;
import com.snowflake.snowpark_java.types.*;
import java.math.BigDecimal;
import java.sql.Date;
import java.sql.Time;
import java.sql.Timestamp;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TimeZone;
import org.junit.Test;

public class JavaRowSuite extends TestBase {
Expand Down Expand Up @@ -429,4 +434,172 @@ public void testGetRow() {
},
getSession());
}

// Round-trips one fully populated row and one all-null row through a DataFrame and checks
// Row.getAs for every column type in the schema, plus the two documented failure modes.
@Test
public void getAs() {
  // Captured once so the stored and the expected timestamp are the same instant.
  long nowMillis = System.currentTimeMillis();

  StructType schema =
      StructType.create(
          new StructField("c01", DataTypes.BinaryType),
          new StructField("c02", DataTypes.BooleanType),
          new StructField("c03", DataTypes.ByteType),
          new StructField("c04", DataTypes.DateType),
          new StructField("c05", DataTypes.DoubleType),
          new StructField("c06", DataTypes.FloatType),
          new StructField("c07", DataTypes.GeographyType),
          new StructField("c08", DataTypes.GeometryType),
          new StructField("c09", DataTypes.IntegerType),
          new StructField("c10", DataTypes.LongType),
          new StructField("c11", DataTypes.ShortType),
          new StructField("c12", DataTypes.StringType),
          new StructField("c13", DataTypes.TimeType),
          new StructField("c14", DataTypes.TimestampType),
          new StructField("c15", DataTypes.VariantType));

  // ---- Row with a value in every column ----
  Row[] populatedData = {
    Row.create(
        new byte[] {1, 2},
        true,
        Byte.MIN_VALUE,
        Date.valueOf("2024-01-01"),
        Double.MIN_VALUE,
        Float.MIN_VALUE,
        Geography.fromGeoJSON("POINT(30 10)"),
        Geometry.fromGeoJSON("POINT(20 40)"),
        Integer.MIN_VALUE,
        Long.MIN_VALUE,
        Short.MIN_VALUE,
        "string",
        Time.valueOf("16:23:04"),
        new Timestamp(nowMillis),
        new Variant(1))
  };
  Row populatedRow = getSession().createDataFrame(populatedData, schema).collect()[0];

  Geography expectedGeography =
      Geography.fromGeoJSON(
          "{\n \"coordinates\": [\n 30,\n 10\n ],\n \"type\": \"Point\"\n}");
  Geometry expectedGeometry =
      Geometry.fromGeoJSON(
          "{\n \"coordinates\": [\n 2.000000000000000e+01,\n 4.000000000000000e+01\n ],\n \"type\": \"Point\"\n}");

  assert Arrays.equals(populatedRow.getAs(0, byte[].class), new byte[] {1, 2});
  assert populatedRow.getAs(1, Boolean.class);
  assert populatedRow.getAs(2, Byte.class) == Byte.MIN_VALUE;
  assert populatedRow.getAs(3, Date.class).equals(Date.valueOf("2024-01-01"));
  assert populatedRow.getAs(4, Double.class) == Double.MIN_VALUE;
  assert populatedRow.getAs(5, Float.class) == Float.MIN_VALUE;
  assert populatedRow.getAs(6, Geography.class).equals(expectedGeography);
  assert populatedRow.getAs(7, Geometry.class).equals(expectedGeometry);
  assert populatedRow.getAs(8, Integer.class) == Integer.MIN_VALUE;
  assert populatedRow.getAs(9, Long.class) == Long.MIN_VALUE;
  assert populatedRow.getAs(10, Short.class) == Short.MIN_VALUE;
  assert populatedRow.getAs(11, String.class).equals("string");
  assert populatedRow.getAs(12, Time.class).equals(Time.valueOf("16:23:04"));
  assert populatedRow.getAs(13, Timestamp.class).equals(new Timestamp(nowMillis));
  assert populatedRow.getAs(14, Variant.class).equals(new Variant(1));

  // ---- Failure modes ----
  assertThrows(
      ClassCastException.class,
      () -> {
        // Assigned to a Boolean local so the compiler-inserted cast to Boolean runs inside
        // the lambda; column 0 holds binary data.
        Boolean b = populatedRow.getAs(0, Boolean.class);
      });
  assertThrows(
      ArrayIndexOutOfBoundsException.class, () -> populatedRow.getAs(-1, Boolean.class));

  // ---- Row with null in every column ----
  Row[] nullData = {
    Row.create(
        null, null, null, null, null,
        null, null, null, null, null,
        null, null, null, null, null)
  };
  Row nullRow = getSession().createDataFrame(nullData, schema).collect()[0];

  // One requested type per schema column, in column order.
  Class<?>[] columnTypes = {
    byte[].class,
    Boolean.class,
    Byte.class,
    Date.class,
    Double.class,
    Float.class,
    Geography.class,
    Geometry.class,
    Integer.class,
    Long.class,
    Short.class,
    String.class,
    Time.class,
    Timestamp.class,
    Variant.class
  };
  for (int column = 0; column < columnTypes.length; column++) {
    assert nullRow.getAs(column, columnTypes[column]) == null;
  }
}

// Checks Row.getAs with Map.class against three structured MAP columns: string keys,
// numeric keys, and a map whose values are themselves maps.
@Test
public void getAsWithStructuredMap() {
  structuredTypeTest(
      () -> {
        String sql =
            "SELECT "
                + "{'a':1,'b':2}::MAP(VARCHAR, NUMBER) as map1,"
                + "{'1':'a','2':'b'}::MAP(NUMBER, VARCHAR) as map2,"
                + "{'1':{'a':1,'b':2},'2':{'c':3}}::MAP(NUMBER, MAP(VARCHAR, NUMBER)) as map3";

        Row result = getSession().sql(sql).collect()[0];

        // MAP(VARCHAR, NUMBER): values are expected as Long.
        Map<?, ?> byName = result.getAs(0, Map.class);
        assert (Long) byName.get("a") == 1L;
        assert (Long) byName.get("b") == 2L;

        // MAP(NUMBER, VARCHAR): keys are expected as Long.
        Map<?, ?> byNumber = result.getAs(1, Map.class);
        assert byNumber.get(1L).equals("a");
        assert byNumber.get(2L).equals("b");

        // MAP(NUMBER, MAP(VARCHAR, NUMBER)): inner values are compared as whole maps.
        Map<?, ?> nested = result.getAs(2, Map.class);
        Map<String, Long> firstInner = new HashMap<>();
        firstInner.put("a", 1L);
        firstInner.put("b", 2L);
        assert nested.get(1L).equals(firstInner);
        assert nested.get(2L).equals(Collections.singletonMap("c", 3L));
      },
      getSession());
}

// Checks Row.getAs with ArrayList.class against structured ARRAY columns of numbers,
// strings, timestamps and nested arrays.
@Test
public void getAsWithStructuredArray() {
  structuredTypeTest(
      () -> {
        // The JVM default time zone is switched for the duration of the test and restored
        // in the finally block, whatever the outcome.
        TimeZone savedTimeZone = TimeZone.getDefault();
        try {
          TimeZone.setDefault(TimeZone.getTimeZone("US/Pacific"));

          String sql =
              "SELECT "
                  + "[1,2,3]::ARRAY(NUMBER) AS arr1,"
                  + "['a','b']::ARRAY(VARCHAR) AS arr2,"
                  + "[parse_json(31000000)::timestamp_ntz]::ARRAY(TIMESTAMP_NTZ) AS arr3,"
                  + "[[1,2]]::ARRAY(ARRAY) AS arr4";

          Row result = getSession().sql(sql).collect()[0];

          ArrayList<?> numbers = result.getAs(0, ArrayList.class);
          assert numbers.equals(Arrays.asList(1L, 2L, 3L));

          ArrayList<?> strings = result.getAs(1, ArrayList.class);
          assert strings.equals(Arrays.asList("a", "b"));

          ArrayList<?> timestamps = result.getAs(2, ArrayList.class);
          assert timestamps.equals(Collections.singletonList(new Timestamp(31000000000L)));

          // The nested array is expected back as a single string element.
          ArrayList<?> nestedArrays = result.getAs(3, ArrayList.class);
          assert nestedArrays.equals(Collections.singletonList("[\n 1,\n 2\n]"));
        } finally {
          TimeZone.setDefault(savedTimeZone);
        }
      },
      getSession());
}
}
Loading

0 comments on commit eed60a4

Please sign in to comment.