From e57e78a5c4b765603fbc25c259e1ba20fcce46c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luis=20Fallas=20Avenda=C3=B1o?= Date: Mon, 7 Oct 2024 12:09:41 -0600 Subject: [PATCH] [SIT-2214] Add support for `com.snowflake.snowpark.Row.getAs(String)` (#166) * [SIT-2214] Add support for `com.snowflake.snowpark.Row.getAs(String)` Adds an implementation for the `Row.getAs(String)` for retreving values from a `Row` using field name. * Add missing Java methods --- .../java/com/snowflake/snowpark_java/Row.java | 44 +++++++++ .../snowpark_java/types/StructType.java | 12 +++ .../scala/com/snowflake/snowpark/Row.scala | 59 ++++++++++-- .../snowpark/internal/ServerConnection.scala | 95 ++++++++++--------- .../snowflake/snowpark/types/StructType.scala | 18 ++++ .../snowflake/snowpark_test/JavaRowSuite.java | 44 +++++++++ .../snowflake/snowpark_test/RowSuite.scala | 27 ++++++ 7 files changed, 246 insertions(+), 53 deletions(-) diff --git a/src/main/java/com/snowflake/snowpark_java/Row.java b/src/main/java/com/snowflake/snowpark_java/Row.java index 74475e2d..07403cb1 100644 --- a/src/main/java/com/snowflake/snowpark_java/Row.java +++ b/src/main/java/com/snowflake/snowpark_java/Row.java @@ -425,6 +425,18 @@ public Row getObject(int index) { return (Row) get(index); } + /** + * Returns the index of the field with the specified name. + * + * @param fieldName the name of the field. + * @return the index of the specified field. + * @throws UnsupportedOperationException if schema information is not available. + * @since 1.15.0 + */ + public int fieldIndex(String fieldName) { + return this.scalaRow.fieldIndex(fieldName); + } + /** * Returns the value at the specified column index and casts it to the desired type {@code T}. * @@ -483,6 +495,38 @@ public T getAs(int index, Class clazz) return (T) get(index); } + /** + * Returns the field value for the specified field name and casts it to the desired type {@code + * T}. + * + *

Example: + * + *

{@code
+   * StructType schema =
+   *     StructType.create(
+   *        new StructField("name", DataTypes.StringType),
+   *        new StructField("val", DataTypes.IntegerType));
+   * Row[] data = { Row.create("Alice", 1) };
+   * DataFrame df = session.createDataFrame(data, schema);
+   * Row row = df.collect()[0];
+   *
+   * row.getAs("name", String.class); // Returns "Alice" as a String
+   * row.getAs("val", Integer.class); // Returns 1 as an Int
+   * }
+ * + * @param fieldName the name of the field within the row. + * @param clazz the {@code Class} object representing the type {@code T}. + * @param the expected type of the value for the specified field name. + * @return the field value for the specified field name cast to type {@code T}. + * @throws ClassCastException if the value of the field cannot be cast to type {@code T}. + * @throws IllegalArgumentException if the name of the field is not part of the row schema. + * @throws UnsupportedOperationException if the schema information is not available. + * @since 1.15.0 + */ + public T getAs(String fieldName, Class clazz) { + return this.getAs(this.scalaRow.fieldIndex(fieldName), clazz); + } + /** * Generates a string value to represent the content of this row. * diff --git a/src/main/java/com/snowflake/snowpark_java/types/StructType.java b/src/main/java/com/snowflake/snowpark_java/types/StructType.java index 63998e3d..2b738129 100644 --- a/src/main/java/com/snowflake/snowpark_java/types/StructType.java +++ b/src/main/java/com/snowflake/snowpark_java/types/StructType.java @@ -58,6 +58,18 @@ private static com.snowflake.snowpark.types.StructField[] toScalaFieldsArray( return result; } + /** + * Return the index of the specified field. + * + * @param fieldName the name of the field. + * @return the index of the field with the specified name. + * @throws IllegalArgumentException if the given field name does not exist in the schema. + * @since 1.15.0 + */ + public int fieldIndex(String fieldName) { + return this.scalaStructType.fieldIndex(fieldName); + } + /** * Retrieves the names of StructField. * diff --git a/src/main/scala/com/snowflake/snowpark/Row.scala b/src/main/scala/com/snowflake/snowpark/Row.scala index 34eb9bbf..8fb0df20 100644 --- a/src/main/scala/com/snowflake/snowpark/Row.scala +++ b/src/main/scala/com/snowflake/snowpark/Row.scala @@ -2,7 +2,7 @@ package com.snowflake.snowpark import java.sql.{Date, Time, Timestamp} import com.snowflake.snowpark.internal.ErrorMessage -import com.snowflake.snowpark.types.{Geography, Geometry, Variant} +import com.snowflake.snowpark.types.{Geography, Geometry, StructType, Variant} import scala.reflect.ClassTag import scala.util.hashing.MurmurHash3 @@ -16,19 +16,22 @@ object Row { * Returns a [[Row]] based on the given values. * @since 0.1.0 */ - def apply(values: Any*): Row = new Row(values.toArray) + def apply(values: Any*): Row = new Row(values.toArray, None) /** * Return a [[Row]] based on the values in the given Seq. * @since 0.1.0 */ - def fromSeq(values: Seq[Any]): Row = new Row(values.toArray) + def fromSeq(values: Seq[Any]): Row = new Row(values.toArray, None) /** * Return a [[Row]] based on the values in the given Array. * @since 0.2.0 */ - def fromArray(values: Array[Any]): Row = new Row(values) + def fromArray(values: Array[Any]): Row = new Row(values, None) + + private[snowpark] def fromSeqWithSchema(values: Seq[Any], schema: Option[StructType]): Row = + new Row(values.toArray, schema) private[snowpark] def fromMap(map: Map[String, Any]): Row = new SnowflakeObject(map) @@ -36,7 +39,7 @@ object Row { private[snowpark] class SnowflakeObject private[snowpark] ( private[snowpark] val map: Map[String, Any]) - extends Row(map.values.toArray) { + extends Row(map.values.toArray, None) { override def toString: String = convertValueToString(this) } @@ -47,7 +50,7 @@ private[snowpark] class SnowflakeObject private[snowpark] ( * @groupname utl Utility Functions * @since 0.1.0 */ -class Row protected (values: Array[Any]) extends Serializable { +class Row protected (values: Array[Any], schema: Option[StructType]) extends Serializable { /** * Converts this [[Row]] to a Seq @@ -89,7 +92,7 @@ class Row protected (values: Array[Any]) extends Serializable { * @since 0.1.0 * @group utl */ - def copy(): Row = new Row(values) + def copy(): Row = new Row(values, schema) /** * Returns a clone of this row object. Alias of [[copy]] @@ -367,6 +370,48 @@ class Row protected (values: Array[Any]) extends Serializable { getAs[Map[T, U]](index) } + /** + * Returns the index of the field with the specified name. + * + * @param fieldName the name of the field. + * @return the index of the specified field. + * @throws UnsupportedOperationException if schema information is not available. + * @since 1.15.0 + */ + def fieldIndex(fieldName: String): Int = { + var schema = this.schema.getOrElse( + throw new UnsupportedOperationException("Cannot get field index for row without schema")) + schema.fieldIndex(fieldName) + } + + /** + * Returns the value for the specified field name and casts it to the desired type `T`. + * + * Example: + * + * {{{ + * val schema = + * StructType(Seq(StructField("name", StringType), StructField("value", IntegerType))) + * val data = Seq(Row("Alice", 1)) + * val df = session.createDataFrame(data, schema) + * val row = df.collect()(0) + * + * row.getAs[String]("name") // Returns "Alice" as a String + * row.getAs[Int]("value") // Returns 1 as an Int + * }}} + * + * @param fieldName the name of the field within the row. + * @tparam T the expected type of the value for the specified field name. + * @return the value for the specified field name cast to type `T`. + * @throws ClassCastException if the value of the field cannot be cast to type `T`. + * @throws IllegalArgumentException if the name of the field is not part of the row schema. + * @throws UnsupportedOperationException if the schema information is not available. + * @group getter + * @since 1.15.0 + */ + def getAs[T](fieldName: String)(implicit classTag: ClassTag[T]): T = + getAs[T](fieldIndex(fieldName)) + /** * Returns the value at the specified column index and casts it to the desired type `T`. * diff --git a/src/main/scala/com/snowflake/snowpark/internal/ServerConnection.scala b/src/main/scala/com/snowflake/snowpark/internal/ServerConnection.scala index 92728eaf..a2281925 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/ServerConnection.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/ServerConnection.scala @@ -330,6 +330,7 @@ private[snowpark] class ServerConnection( val data = statement.getResultSet val schema = ServerConnection.convertResultMetaToAttribute(data.getMetaData) + val schemaOption = Some(StructType.fromAttributes(schema)) lazy val geographyOutputFormat = getParameterValue(ParameterUtils.GeographyOutputFormat) lazy val geometryOutputFormat = getParameterValue(ParameterUtils.GeometryOutputFormat) @@ -343,53 +344,55 @@ private[snowpark] class ServerConnection( private def readNext(): Unit = { _hasNext = data.next() _currentRow = if (_hasNext) { - Row.fromSeq(schema.zipWithIndex.map { - case (attribute, index) => - val resultIndex: Int = index + 1 - val resultSetExt = SnowflakeResultSetExt(data) - if (resultSetExt.isNull(resultIndex)) { - null - } else { - attribute.dataType match { - case VariantType => data.getString(resultIndex) - case _: StructuredArrayType | _: StructuredMapType | _: StructType => - resultSetExt.getObject(resultIndex) - case ArrayType(StringType) => data.getString(resultIndex) - case MapType(StringType, StringType) => data.getString(resultIndex) - case StringType => data.getString(resultIndex) - case _: DecimalType => data.getBigDecimal(resultIndex) - case DoubleType => data.getDouble(resultIndex) - case FloatType => data.getFloat(resultIndex) - case BooleanType => data.getBoolean(resultIndex) - case BinaryType => data.getBytes(resultIndex) - case DateType => data.getDate(resultIndex) - case TimeType => data.getTime(resultIndex) - case ByteType => data.getByte(resultIndex) - case IntegerType => data.getInt(resultIndex) - case LongType => data.getLong(resultIndex) - case TimestampType => data.getTimestamp(resultIndex) - case ShortType => data.getShort(resultIndex) - case GeographyType => - geographyOutputFormat match { - case "GeoJSON" => Geography.fromGeoJSON(data.getString(resultIndex)) - case _ => - throw ErrorMessage.MISC_UNSUPPORTED_GEOGRAPHY_FORMAT( - geographyOutputFormat) - } - case GeometryType => - geometryOutputFormat match { - case "GeoJSON" => Geometry.fromGeoJSON(data.getString(resultIndex)) - case _ => - throw ErrorMessage.MISC_UNSUPPORTED_GEOMETRY_FORMAT( - geometryOutputFormat) - } - case _ => - // ArrayType, StructType, MapType - throw new UnsupportedOperationException( - s"Unsupported type: ${attribute.dataType}") + Row.fromSeqWithSchema( + schema.zipWithIndex.map { + case (attribute, index) => + val resultIndex: Int = index + 1 + val resultSetExt = SnowflakeResultSetExt(data) + if (resultSetExt.isNull(resultIndex)) { + null + } else { + attribute.dataType match { + case VariantType => data.getString(resultIndex) + case _: StructuredArrayType | _: StructuredMapType | _: StructType => + resultSetExt.getObject(resultIndex) + case ArrayType(StringType) => data.getString(resultIndex) + case MapType(StringType, StringType) => data.getString(resultIndex) + case StringType => data.getString(resultIndex) + case _: DecimalType => data.getBigDecimal(resultIndex) + case DoubleType => data.getDouble(resultIndex) + case FloatType => data.getFloat(resultIndex) + case BooleanType => data.getBoolean(resultIndex) + case BinaryType => data.getBytes(resultIndex) + case DateType => data.getDate(resultIndex) + case TimeType => data.getTime(resultIndex) + case ByteType => data.getByte(resultIndex) + case IntegerType => data.getInt(resultIndex) + case LongType => data.getLong(resultIndex) + case TimestampType => data.getTimestamp(resultIndex) + case ShortType => data.getShort(resultIndex) + case GeographyType => + geographyOutputFormat match { + case "GeoJSON" => Geography.fromGeoJSON(data.getString(resultIndex)) + case _ => + throw ErrorMessage.MISC_UNSUPPORTED_GEOGRAPHY_FORMAT( + geographyOutputFormat) + } + case GeometryType => + geometryOutputFormat match { + case "GeoJSON" => Geometry.fromGeoJSON(data.getString(resultIndex)) + case _ => + throw ErrorMessage.MISC_UNSUPPORTED_GEOMETRY_FORMAT( + geometryOutputFormat) + } + case _ => + // ArrayType, StructType, MapType + throw new UnsupportedOperationException( + s"Unsupported type: ${attribute.dataType}") + } } - } - }) + }, + schemaOption) } else { // After all rows are consumed, close the statement to release resource close() diff --git a/src/main/scala/com/snowflake/snowpark/types/StructType.scala b/src/main/scala/com/snowflake/snowpark/types/StructType.scala index ff8869df..d985a98d 100644 --- a/src/main/scala/com/snowflake/snowpark/types/StructType.scala +++ b/src/main/scala/com/snowflake/snowpark/types/StructType.scala @@ -40,6 +40,10 @@ case class StructType(fields: Array[StructField] = Array()) extends DataType with Seq[StructField] { + private lazy val fieldPositions = scala.collection.immutable + .SortedMap(fields.zipWithIndex.map(tuple => (tuple._1.name -> tuple._2)): _*)( + scala.math.Ordering.comparatorToOrdering(String.CASE_INSENSITIVE_ORDER)) + /** * Returns the total number of [[StructField]] * @since 0.1.0 @@ -101,6 +105,20 @@ case class StructType(fields: Array[StructField] = Array()) nameToField(name).getOrElse( throw new IllegalArgumentException(s"$name does not exits. Names: ${names.mkString(", ")}")) + /** + * Return the index of the specified field. + * + * @param fieldName the name of the field. + * @return the index of the field with the specified name. + * @throws IllegalArgumentException if the given field name does not exist in the schema. + * @since 1.15.0 + */ + def fieldIndex(fieldName: String): Int = { + fieldPositions.getOrElse( + fieldName, + throw new IllegalArgumentException("Field " + fieldName + " does not exist")) + } + protected[snowpark] def toAttributes: Seq[Attribute] = { /* * When user provided schema is used in a SnowflakePlan, we have to diff --git a/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java index f8918292..88294511 100644 --- a/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java +++ b/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java @@ -602,4 +602,48 @@ public void getAsWithStructuredArray() { }, getSession()); } + + @Test + public void getAsWithFieldName() { + StructType schema = + StructType.create( + new StructField("EmpName", DataTypes.StringType), + new StructField("NumVal", DataTypes.IntegerType)); + + Row[] data = {Row.create("abcd", 10), Row.create("efgh", 20)}; + + DataFrame df = getSession().createDataFrame(data, schema); + Row row = df.collect()[0]; + + assert (row.getAs("EmpName", String.class) == row.getAs(0, String.class)); + assert (row.getAs("EmpName", String.class).charAt(3) == 'd'); + assert (row.getAs("NumVal", Integer.class) == row.getAs(1, Integer.class)); + + assert (row.getAs("EMPNAME", String.class) == row.getAs(0, String.class)); + + assertThrows( + IllegalArgumentException.class, () -> row.getAs("NonExistingColumn", Integer.class)); + + Row rowWithoutSchema = Row.create(40, "Alice"); + assertThrows( + UnsupportedOperationException.class, + () -> rowWithoutSchema.getAs("NonExistingColumn", Integer.class)); + } + + @Test + public void fieldIndex() { + StructType schema = + StructType.create( + new StructField("EmpName", DataTypes.StringType), + new StructField("NumVal", DataTypes.IntegerType)); + + Row[] data = {Row.create("abcd", 10), Row.create("efgh", 20)}; + + DataFrame df = getSession().createDataFrame(data, schema); + Row row = df.collect()[0]; + + assert (row.fieldIndex("EmpName") == 0); + assert (row.fieldIndex("NumVal") == 1); + assertThrows(IllegalArgumentException.class, () -> row.fieldIndex("NonExistingColumn")); + } } diff --git a/src/test/scala/com/snowflake/snowpark_test/RowSuite.scala b/src/test/scala/com/snowflake/snowpark_test/RowSuite.scala index df87666f..c92b4c7e 100644 --- a/src/test/scala/com/snowflake/snowpark_test/RowSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/RowSuite.scala @@ -404,6 +404,33 @@ class RowSuite extends SNTestBase { } } + test("getAs with field name") { + val schema = + StructType(Seq(StructField("EmpName", StringType), StructField("NumVal", IntegerType))) + val df = session.createDataFrame(Seq(Row("abcd", 10), Row("efgh", 20)), schema) + val row = df.collect()(0) + + assert(row.getAs[String]("EmpName") == row.getAs[String](0)) + assert(row.getAs[String]("EmpName").charAt(3) == 'd') + assert(row.getAs[Int]("NumVal") == row.getAs[Int](1)) + + assert(row.getAs[String]("EMPNAME") == row.getAs[String](0)) + + assertThrows[IllegalArgumentException](row.getAs[String]("NonExistingColumn")) + + val rowWithoutSchema = Row(40, "Alice") + assertThrows[UnsupportedOperationException]( + rowWithoutSchema.getAs[Integer]("NonExistingColumn")); + } + + test("fieldIndex") { + val schema = + StructType(Seq(StructField("EmpName", StringType), StructField("NumVal", IntegerType))) + assert(schema.fieldIndex("EmpName") == 0) + assert(schema.fieldIndex("NumVal") == 1) + assertThrows[IllegalArgumentException](schema.fieldIndex("NonExistingColumn")) + } + test("hashCode") { val row1 = Row(1, 2, 3) val row2 = Row("str", null, 3)