Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SIT-2214] Add support for com.snowflake.snowpark.Row.getAs(String) #166

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions src/main/java/com/snowflake/snowpark_java/Row.java
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,18 @@ public Row getObject(int index) {
return (Row) get(index);
}

/**
 * Looks up the position of a field in this row by its name.
 *
 * <p>The lookup is delegated to the wrapped Scala row.
 *
 * @param fieldName the name of the field to locate.
 * @return the index of the field named {@code fieldName}.
 * @throws IllegalArgumentException if the name of the field is not part of the row schema.
 * @throws UnsupportedOperationException if schema information is not available.
 * @since 1.15.0
 */
public int fieldIndex(String fieldName) {
  final int position = scalaRow.fieldIndex(fieldName);
  return position;
}

/**
* Returns the value at the specified column index and casts it to the desired type {@code T}.
*
Expand Down Expand Up @@ -483,6 +495,38 @@ public <T> T getAs(int index, Class<T> clazz)
return (T) get(index);
}

/**
 * Looks up a field by name and returns its value cast to the requested type {@code T}.
 *
 * <p>The field name is first resolved to a column index through the wrapped Scala row; the value
 * is then read through the index-based {@code getAs} overload.
 *
 * <p>Example:
 *
 * <pre>{@code
 * StructType schema =
 *     StructType.create(
 *         new StructField("name", DataTypes.StringType),
 *         new StructField("val", DataTypes.IntegerType));
 * Row[] data = { Row.create("Alice", 1) };
 * DataFrame df = session.createDataFrame(data, schema);
 * Row row = df.collect()[0];
 *
 * row.getAs("name", String.class); // Returns "Alice" as a String
 * row.getAs("val", Integer.class); // Returns 1 as an Integer
 * }</pre>
 *
 * @param fieldName the name of the field within the row.
 * @param clazz the {@code Class} object representing the type {@code T}.
 * @param <T> the expected type of the value for the specified field name.
 * @return the field value for the specified field name cast to type {@code T}.
 * @throws ClassCastException if the value of the field cannot be cast to type {@code T}.
 * @throws IllegalArgumentException if the name of the field is not part of the row schema.
 * @throws UnsupportedOperationException if the schema information is not available.
 * @since 1.15.0
 */
public <T> T getAs(String fieldName, Class<T> clazz) {
  // Resolve the name to a position first, then reuse the index-based accessor.
  final int columnIndex = scalaRow.fieldIndex(fieldName);
  return getAs(columnIndex, clazz);
}

/**
* Generates a string value to represent the content of this row.
*
Expand Down
12 changes: 12 additions & 0 deletions src/main/java/com/snowflake/snowpark_java/types/StructType.java
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,18 @@ private static com.snowflake.snowpark.types.StructField[] toScalaFieldsArray(
return result;
}

/**
 * Returns the position of the field with the given name within this schema.
 *
 * <p>The lookup is delegated to the wrapped Scala {@code StructType}.
 *
 * @param fieldName the name of the field to locate.
 * @return the index of the field with the specified name.
 * @throws IllegalArgumentException if the given field name does not exist in the schema.
 * @since 1.15.0
 */
public int fieldIndex(String fieldName) {
  final int position = scalaStructType.fieldIndex(fieldName);
  return position;
}

/**
* Retrieves the names of StructField.
*
Expand Down
59 changes: 52 additions & 7 deletions src/main/scala/com/snowflake/snowpark/Row.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package com.snowflake.snowpark

import java.sql.{Date, Time, Timestamp}
import com.snowflake.snowpark.internal.ErrorMessage
import com.snowflake.snowpark.types.{Geography, Geometry, Variant}
import com.snowflake.snowpark.types.{Geography, Geometry, StructType, Variant}

import scala.reflect.ClassTag
import scala.util.hashing.MurmurHash3
Expand All @@ -16,27 +16,30 @@ object Row {
* Returns a [[Row]] based on the given values.
* @since 0.1.0
*/
def apply(values: Any*): Row = new Row(values.toArray)
def apply(values: Any*): Row = new Row(values.toArray, None)

/**
* Return a [[Row]] based on the values in the given Seq.
* @since 0.1.0
*/
def fromSeq(values: Seq[Any]): Row = new Row(values.toArray)
def fromSeq(values: Seq[Any]): Row = new Row(values.toArray, None)

/**
* Return a [[Row]] based on the values in the given Array.
* @since 0.2.0
*/
def fromArray(values: Array[Any]): Row = new Row(values)
def fromArray(values: Array[Any]): Row = new Row(values, None)

private[snowpark] def fromSeqWithSchema(values: Seq[Any], schema: Option[StructType]): Row =
new Row(values.toArray, schema)

private[snowpark] def fromMap(map: Map[String, Any]): Row =
new SnowflakeObject(map)
}

private[snowpark] class SnowflakeObject private[snowpark] (
private[snowpark] val map: Map[String, Any])
extends Row(map.values.toArray) {
extends Row(map.values.toArray, None) {
override def toString: String = convertValueToString(this)
}

Expand All @@ -47,7 +50,7 @@ private[snowpark] class SnowflakeObject private[snowpark] (
* @groupname utl Utility Functions
* @since 0.1.0
*/
class Row protected (values: Array[Any]) extends Serializable {
class Row protected (values: Array[Any], schema: Option[StructType]) extends Serializable {

/**
* Converts this [[Row]] to a Seq
Expand Down Expand Up @@ -89,7 +92,7 @@ class Row protected (values: Array[Any]) extends Serializable {
* @since 0.1.0
* @group utl
*/
def copy(): Row = new Row(values)
def copy(): Row = new Row(values, schema)

/**
* Returns a clone of this row object. Alias of [[copy]]
Expand Down Expand Up @@ -367,6 +370,48 @@ class Row protected (values: Array[Any]) extends Serializable {
getAs[Map[T, U]](index)
}

/**
 * Returns the index of the field with the specified name.
 *
 * The lookup requires this row to carry schema information; the name is then resolved
 * by the schema itself.
 *
 * @param fieldName the name of the field.
 * @return the index of the specified field.
 * @throws IllegalArgumentException if the name of the field is not part of the row schema.
 * @throws UnsupportedOperationException if schema information is not available.
 * @since 1.15.0
 */
def fieldIndex(fieldName: String): Int =
  // No local binding: the previous `var schema` was never reassigned and shadowed the
  // constructor parameter of the same name.
  schema
    .getOrElse(
      throw new UnsupportedOperationException("Cannot get field index for row without schema"))
    .fieldIndex(fieldName)

/**
 * Looks up a field by name and returns its value cast to the desired type `T`.
 *
 * The name is resolved to a column index with [[fieldIndex]], and the value is then read
 * through the index-based `getAs`.
 *
 * Example:
 *
 * {{{
 *   val schema =
 *     StructType(Seq(StructField("name", StringType), StructField("value", IntegerType)))
 *   val data = Seq(Row("Alice", 1))
 *   val df = session.createDataFrame(data, schema)
 *   val row = df.collect()(0)
 *
 *   row.getAs[String]("name") // Returns "Alice" as a String
 *   row.getAs[Int]("value") // Returns 1 as an Int
 * }}}
 *
 * @param fieldName the name of the field within the row.
 * @tparam T the expected type of the value for the specified field name.
 * @return the value for the specified field name cast to type `T`.
 * @throws ClassCastException if the value of the field cannot be cast to type `T`.
 * @throws IllegalArgumentException if the name of the field is not part of the row schema.
 * @throws UnsupportedOperationException if the schema information is not available.
 * @group getter
 * @since 1.15.0
 */
def getAs[T](fieldName: String)(implicit classTag: ClassTag[T]): T = {
  val columnIndex = fieldIndex(fieldName)
  getAs[T](columnIndex)
}

/**
* Returns the value at the specified column index and casts it to the desired type `T`.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,7 @@ private[snowpark] class ServerConnection(
val data = statement.getResultSet

val schema = ServerConnection.convertResultMetaToAttribute(data.getMetaData)
val schemaOption = Some(StructType.fromAttributes(schema))

lazy val geographyOutputFormat = getParameterValue(ParameterUtils.GeographyOutputFormat)
lazy val geometryOutputFormat = getParameterValue(ParameterUtils.GeometryOutputFormat)
Expand All @@ -343,53 +344,55 @@ private[snowpark] class ServerConnection(
private def readNext(): Unit = {
_hasNext = data.next()
_currentRow = if (_hasNext) {
Row.fromSeq(schema.zipWithIndex.map {
case (attribute, index) =>
val resultIndex: Int = index + 1
val resultSetExt = SnowflakeResultSetExt(data)
if (resultSetExt.isNull(resultIndex)) {
null
} else {
attribute.dataType match {
case VariantType => data.getString(resultIndex)
case _: StructuredArrayType | _: StructuredMapType | _: StructType =>
resultSetExt.getObject(resultIndex)
case ArrayType(StringType) => data.getString(resultIndex)
case MapType(StringType, StringType) => data.getString(resultIndex)
case StringType => data.getString(resultIndex)
case _: DecimalType => data.getBigDecimal(resultIndex)
case DoubleType => data.getDouble(resultIndex)
case FloatType => data.getFloat(resultIndex)
case BooleanType => data.getBoolean(resultIndex)
case BinaryType => data.getBytes(resultIndex)
case DateType => data.getDate(resultIndex)
case TimeType => data.getTime(resultIndex)
case ByteType => data.getByte(resultIndex)
case IntegerType => data.getInt(resultIndex)
case LongType => data.getLong(resultIndex)
case TimestampType => data.getTimestamp(resultIndex)
case ShortType => data.getShort(resultIndex)
case GeographyType =>
geographyOutputFormat match {
case "GeoJSON" => Geography.fromGeoJSON(data.getString(resultIndex))
case _ =>
throw ErrorMessage.MISC_UNSUPPORTED_GEOGRAPHY_FORMAT(
geographyOutputFormat)
}
case GeometryType =>
geometryOutputFormat match {
case "GeoJSON" => Geometry.fromGeoJSON(data.getString(resultIndex))
case _ =>
throw ErrorMessage.MISC_UNSUPPORTED_GEOMETRY_FORMAT(
geometryOutputFormat)
}
case _ =>
// ArrayType, StructType, MapType
throw new UnsupportedOperationException(
s"Unsupported type: ${attribute.dataType}")
Row.fromSeqWithSchema(
schema.zipWithIndex.map {
case (attribute, index) =>
val resultIndex: Int = index + 1
val resultSetExt = SnowflakeResultSetExt(data)
if (resultSetExt.isNull(resultIndex)) {
null
} else {
attribute.dataType match {
case VariantType => data.getString(resultIndex)
case _: StructuredArrayType | _: StructuredMapType | _: StructType =>
resultSetExt.getObject(resultIndex)
case ArrayType(StringType) => data.getString(resultIndex)
case MapType(StringType, StringType) => data.getString(resultIndex)
case StringType => data.getString(resultIndex)
case _: DecimalType => data.getBigDecimal(resultIndex)
case DoubleType => data.getDouble(resultIndex)
case FloatType => data.getFloat(resultIndex)
case BooleanType => data.getBoolean(resultIndex)
case BinaryType => data.getBytes(resultIndex)
case DateType => data.getDate(resultIndex)
case TimeType => data.getTime(resultIndex)
case ByteType => data.getByte(resultIndex)
case IntegerType => data.getInt(resultIndex)
case LongType => data.getLong(resultIndex)
case TimestampType => data.getTimestamp(resultIndex)
case ShortType => data.getShort(resultIndex)
case GeographyType =>
geographyOutputFormat match {
case "GeoJSON" => Geography.fromGeoJSON(data.getString(resultIndex))
case _ =>
throw ErrorMessage.MISC_UNSUPPORTED_GEOGRAPHY_FORMAT(
geographyOutputFormat)
}
case GeometryType =>
geometryOutputFormat match {
case "GeoJSON" => Geometry.fromGeoJSON(data.getString(resultIndex))
case _ =>
throw ErrorMessage.MISC_UNSUPPORTED_GEOMETRY_FORMAT(
geometryOutputFormat)
}
case _ =>
// ArrayType, StructType, MapType
throw new UnsupportedOperationException(
s"Unsupported type: ${attribute.dataType}")
}
}
}
})
},
schemaOption)
} else {
// After all rows are consumed, close the statement to release resource
close()
Expand Down
18 changes: 18 additions & 0 deletions src/main/scala/com/snowflake/snowpark/types/StructType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ case class StructType(fields: Array[StructField] = Array())
extends DataType
with Seq[StructField] {

// Maps each field name to its position in `fields`, built on first use.
// Keys are ordered with String.CASE_INSENSITIVE_ORDER, so lookups ignore case.
// NOTE(review): names that differ only by case presumably collapse into a single
// entry under this ordering -- confirm that this is acceptable for callers.
private lazy val fieldPositions = {
  val ignoringCase: scala.math.Ordering[String] =
    scala.math.Ordering.comparatorToOrdering(String.CASE_INSENSITIVE_ORDER)
  val nameToPosition = fields.zipWithIndex.map {
    case (field, position) => field.name -> position
  }
  scala.collection.immutable.SortedMap(nameToPosition: _*)(ignoringCase)
}

/**
* Returns the total number of [[StructField]]
* @since 0.1.0
Expand Down Expand Up @@ -101,6 +105,20 @@ case class StructType(fields: Array[StructField] = Array())
nameToField(name).getOrElse(
throw new IllegalArgumentException(s"$name does not exits. Names: ${names.mkString(", ")}"))

/**
 * Returns the position of the field with the given name.
 *
 * The name is looked up in the name-to-position map of this schema, which compares
 * names without regard to case.
 *
 * @param fieldName the name of the field.
 * @return the index of the field with the specified name.
 * @throws IllegalArgumentException if the given field name does not exist in the schema.
 * @since 1.15.0
 */
def fieldIndex(fieldName: String): Int =
  fieldPositions.get(fieldName) match {
    case Some(position) => position
    case None =>
      throw new IllegalArgumentException(s"Field $fieldName does not exist")
  }

protected[snowpark] def toAttributes: Seq[Attribute] = {
/*
* When user provided schema is used in a SnowflakePlan, we have to
Expand Down
44 changes: 44 additions & 0 deletions src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java
Original file line number Diff line number Diff line change
Expand Up @@ -602,4 +602,48 @@ public void getAsWithStructuredArray() {
},
getSession());
}

// Verifies name-based getAs: agreement with index-based access, case-insensitive names,
// and the exceptions raised for an unknown field and for a row without schema.
@Test
public void getAsWithFieldName() {
  StructField nameField = new StructField("EmpName", DataTypes.StringType);
  StructField numberField = new StructField("NumVal", DataTypes.IntegerType);
  StructType schema = StructType.create(nameField, numberField);

  Row[] data = new Row[] {Row.create("abcd", 10), Row.create("efgh", 20)};

  DataFrame df = getSession().createDataFrame(data, schema);
  Row[] collected = df.collect();
  Row row = collected[0];

  // NOTE(review): == on String/Integer compares references; presumably both lookups
  // hand back the same stored object -- confirm, or compare with equals().
  assert (row.getAs("EmpName", String.class) == row.getAs(0, String.class));
  assert (row.getAs("EmpName", String.class).charAt(3) == 'd');
  assert (row.getAs("NumVal", Integer.class) == row.getAs(1, Integer.class));

  // Same field, different letter case.
  assert (row.getAs("EMPNAME", String.class) == row.getAs(0, String.class));

  // A name missing from the schema is rejected.
  assertThrows(
      IllegalArgumentException.class, () -> row.getAs("NonExistingColumn", Integer.class));

  // A row created directly from values has no schema to resolve names against.
  Row rowWithoutSchema = Row.create(40, "Alice");
  assertThrows(
      UnsupportedOperationException.class,
      () -> rowWithoutSchema.getAs("NonExistingColumn", Integer.class));
}

// Verifies that fieldIndex maps each schema field name to its position and rejects
// a name that is not part of the schema.
@Test
public void fieldIndex() {
  StructField nameField = new StructField("EmpName", DataTypes.StringType);
  StructField numberField = new StructField("NumVal", DataTypes.IntegerType);
  StructType schema = StructType.create(nameField, numberField);

  Row[] data = new Row[] {Row.create("abcd", 10), Row.create("efgh", 20)};

  DataFrame df = getSession().createDataFrame(data, schema);
  Row[] collected = df.collect();
  Row row = collected[0];

  assert (row.fieldIndex("EmpName") == 0);
  assert (row.fieldIndex("NumVal") == 1);
  assertThrows(IllegalArgumentException.class, () -> row.fieldIndex("NonExistingColumn"));
}
}
Loading
Loading