diff --git a/src/main/java/com/snowflake/snowpark_java/Functions.java b/src/main/java/com/snowflake/snowpark_java/Functions.java
index b75f081e..ce790653 100644
--- a/src/main/java/com/snowflake/snowpark_java/Functions.java
+++ b/src/main/java/com/snowflake/snowpark_java/Functions.java
@@ -1028,6 +1028,202 @@ public static Column pow(Column l, Column r) {
return new Column(com.snowflake.snowpark.functions.pow(l.toScalaColumn(), r.toScalaColumn()));
}
+ /**
+ * Returns a number (l) raised to the specified power (r).
+ *
+ *
Example:
+ *
+ *
{@code
+ * DataFrame df = session.sql("select * from (values (0.1, 2), (2, 3), (2, 0.5), (2, -1)) as T(base, exponent)");
+ * df.select(col("base"), col("exponent"), pow(col("base"), "exponent").as("result")).show();
+ *
+ * ----------------------------------------------
+ * |"BASE" |"EXPONENT" |"RESULT" |
+ * ----------------------------------------------
+ * |0.1 |2.0 |0.010000000000000002 |
+ * |2.0 |3.0 |8.0 |
+ * |2.0 |0.5 |1.4142135623730951 |
+ * |2.0 |-1.0 |0.5 |
+ * ----------------------------------------------
+ * }
+ *
+ * @param l The numeric column representing the base.
+ * @param r The name of the numeric column representing the exponent.
+ * @return A column containing the result of raising {@code l} to the power of {@code r}.
+ * @since 1.15.0
+ */
+ public static Column pow(Column l, String r) {
+ return new Column(com.snowflake.snowpark.functions.pow(l.toScalaColumn(), r));
+ }
+
+ /**
+ * Returns a number (l) raised to the specified power (r).
+ *
+ * Example:
+ *
+ *
{@code
+ * DataFrame df = session.sql("select * from (values (0.1, 2), (2, 3), (2, 0.5), (2, -1)) as T(base, exponent)");
+ * df.select(col("base"), col("exponent"), pow("base", col("exponent")).as("result")).show();
+ *
+ * ----------------------------------------------
+ * |"BASE" |"EXPONENT" |"RESULT" |
+ * ----------------------------------------------
+ * |0.1 |2.0 |0.010000000000000002 |
+ * |2.0 |3.0 |8.0 |
+ * |2.0 |0.5 |1.4142135623730951 |
+ * |2.0 |-1.0 |0.5 |
+ * ----------------------------------------------
+ * }
+ *
+ * @param l The name of the numeric column representing the base.
+ * @param r The numeric column representing the exponent.
+ * @return A column containing the result of raising {@code l} to the power of {@code r}.
+ * @since 1.15.0
+ */
+ public static Column pow(String l, Column r) {
+ return new Column(com.snowflake.snowpark.functions.pow(l, r.toScalaColumn()));
+ }
+
+ /**
+ * Returns a number (l) raised to the specified power (r).
+ *
+ * Example:
+ *
+ *
{@code
+ * DataFrame df = session.sql("select * from (values (0.1, 2), (2, 3), (2, 0.5), (2, -1)) as T(base, exponent)");
+ * df.select(col("base"), col("exponent"), pow("base", "exponent").as("result")).show();
+ *
+ * ----------------------------------------------
+ * |"BASE" |"EXPONENT" |"RESULT" |
+ * ----------------------------------------------
+ * |0.1 |2.0 |0.010000000000000002 |
+ * |2.0 |3.0 |8.0 |
+ * |2.0 |0.5 |1.4142135623730951 |
+ * |2.0 |-1.0 |0.5 |
+ * ----------------------------------------------
+ * }
+ *
+ * @param l The name of the numeric column representing the base.
+ * @param r The name of the numeric column representing the exponent.
+ * @return A column containing the result of raising {@code l} to the power of {@code r}.
+ * @since 1.15.0
+ */
+ public static Column pow(String l, String r) {
+ return new Column(com.snowflake.snowpark.functions.pow(l, r));
+ }
+
+ /**
+ * Returns a number (l) raised to the specified power (r).
+ *
+ * Example:
+ *
+ *
{@code
+ * DataFrame df = session.sql("select * from (values (0.5), (2), (2.5), (4)) as T(base)");
+ * df.select(col("base"), lit(2.0).as("exponent"), pow(col("base"), 2.0).as("result")).show();
+ *
+ * ----------------------------------
+ * |"BASE" |"EXPONENT" |"RESULT" |
+ * ----------------------------------
+ * |0.5 |2.0 |0.25 |
+ * |2.0 |2.0 |4.0 |
+ * |2.5 |2.0 |6.25 |
+ * |4.0 |2.0 |16.0 |
+ * ----------------------------------
+ * }
+ *
+ * @param l The numeric column representing the base.
+ * @param r The value of the exponent.
+ * @return A column containing the result of raising {@code l} to the power of {@code r}.
+ * @since 1.15.0
+ */
+ public static Column pow(Column l, Double r) {
+ return new Column(com.snowflake.snowpark.functions.pow(l.toScalaColumn(), r));
+ }
+
+ /**
+ * Returns a number (l) raised to the specified power (r).
+ *
+ * Example:
+ *
+ *
{@code
+ * DataFrame df = session.sql("select * from (values (0.5), (2), (2.5), (4)) as T(base)");
+ * df.select(col("base"), lit(2.0).as("exponent"), pow("base", 2.0).as("result")).show();
+ *
+ * ----------------------------------
+ * |"BASE" |"EXPONENT" |"RESULT" |
+ * ----------------------------------
+ * |0.5 |2.0 |0.25 |
+ * |2.0 |2.0 |4.0 |
+ * |2.5 |2.0 |6.25 |
+ * |4.0 |2.0 |16.0 |
+ * ----------------------------------
+ * }
+ *
+ * @param l The name of the numeric column representing the base.
+ * @param r The value of the exponent.
+ * @return A column containing the result of raising {@code l} to the power of {@code r}.
+ * @since 1.15.0
+ */
+ public static Column pow(String l, Double r) {
+ return new Column(com.snowflake.snowpark.functions.pow(l, r));
+ }
+
+ /**
+ * Returns a number (l) raised to the specified power (r).
+ *
+ * Example:
+ *
+ *
{@code
+ * DataFrame df = session.sql("select * from (values (0.5), (2), (2.5), (4)) as T(exponent)");
+ * df.select(lit(2.0).as("base"), col("exponent"), pow(2.0, col("exponent")).as("result")).show();
+ *
+ * --------------------------------------------
+ * |"BASE" |"EXPONENT" |"RESULT" |
+ * --------------------------------------------
+ * |2.0 |0.5 |1.4142135623730951 |
+ * |2.0 |2.0 |4.0 |
+ * |2.0 |2.5 |5.656854249492381 |
+ * |2.0 |4.0 |16.0 |
+ * --------------------------------------------
+ * }
+ *
+ * @param l The value of the base.
+ * @param r The numeric column representing the exponent.
+ * @return A column containing the result of raising {@code l} to the power of {@code r}.
+ * @since 1.15.0
+ */
+ public static Column pow(Double l, Column r) {
+ return new Column(com.snowflake.snowpark.functions.pow(l, r.toScalaColumn()));
+ }
+
+ /**
+ * Returns a number (l) raised to the specified power (r).
+ *
+ * Example:
+ *
+ *
{@code
+ * DataFrame df = session.sql("select * from (values (0.5), (2), (2.5), (4)) as T(exponent)");
+ * df.select(lit(2.0).as("base"), col("exponent"), pow(2.0, "exponent").as("result")).show();
+ *
+ * --------------------------------------------
+ * |"BASE" |"EXPONENT" |"RESULT" |
+ * --------------------------------------------
+ * |2.0 |0.5 |1.4142135623730951 |
+ * |2.0 |2.0 |4.0 |
+ * |2.0 |2.5 |5.656854249492381 |
+ * |2.0 |4.0 |16.0 |
+ * --------------------------------------------
+ * }
+ *
+ * @param l The value of the base.
+ * @param r The name of the numeric column representing the exponent.
+ * @return A column containing the result of raising {@code l} to the power of {@code r}.
+ * @since 1.15.0
+ */
+ public static Column pow(Double l, String r) {
+ return new Column(com.snowflake.snowpark.functions.pow(l, r));
+ }
+
/**
* Rounds the numeric values of the given column {@code e} to the {@code scale} decimal places
* using the half away from zero rounding mode.
@@ -4308,11 +4504,142 @@ public static Column from_unixtime(Column ut, String f) {
* [Row(SEQ8(0)=0),Row(SEQ8(0)=1), Row(SEQ8(0)=2)]
* }
*
+ * @return A sequence of monotonically increasing integers, with wrap-around, which happens after
+ * the largest representable 8-byte integer.
* @since 1.15.0
*/
public static Column monotonically_increasing_id() {
  // Build the Scala column first, then wrap it in the Java-facing Column type.
  com.snowflake.snowpark.Column sequence =
      com.snowflake.snowpark.functions.monotonically_increasing_id();
  return new Column(sequence);
}
+ /**
+  * Returns the number of months between the values of the columns named {@code start} and
+  * {@code end}.
+  *
+  * <p>A whole number is returned if both inputs have the same day of month or both are the last
+  * day of their respective months. Otherwise, the difference is calculated assuming 31 days per
+  * month.
+  *
+  * <p>Both arguments are column names, not date literals: the underlying Scala API resolves each
+  * of them to a column before calling {@code MONTHS_BETWEEN}.
+  *
+  * <p>Example:
+  *
+  * <pre>{@code
+  * DataFrame df = session.sql(
+  *     "select * from values('2010-07-02'::Date,'2010-08-02'::Date), "
+  *         + "('2020-08-02'::Date,'2020-12-02'::Date) as t(a,b)");
+  * df.select(months_between("b", "a")).show(); // 1.0 and 4.0
+  * }</pre>
+  *
+  * @param end Name of the column holding the end value: a date, a timestamp, or a string in a
+  *     format that can be cast to a timestamp, such as {@code yyyy-MM-dd} or {@code yyyy-MM-dd
+  *     HH:mm:ss.SSSS}.
+  * @param start Name of the column holding the start value, with the same accepted types as
+  *     {@code end}.
+  * @return A double, or null if either {@code end} or {@code start} held strings that could not be
+  *     cast to a timestamp. Negative if {@code end} is before {@code start}.
+  * @since 1.15.0
+  */
+ public static Column months_between(String end, String start) {
+   // Fully qualified like every other wrapper in this class, so it never depends on an import.
+   return new Column(com.snowflake.snowpark.functions.months_between(end, start));
+ }
+
+ /**
+ * Locate the position of the first occurrence of substr column in the given string. Returns null
+ * if either of the arguments are null.
+ *
+ * Example
+ *
+ *
{@code
+ * SELECT id,
+ * string1,
+ * REGEXP_SUBSTR(string1, 'nevermore\\d') AS substring,
+ * REGEXP_INSTR( string1, 'nevermore\\d') AS position
+ * FROM demo1
+ * ORDER BY id;
+ *
+ * +----+-------------------------------------+------------+----------+
+ * | ID | STRING1 | SUBSTRING | POSITION |
+ * |----+-------------------------------------+------------+----------|
+ * | 1 | nevermore1, nevermore2, nevermore3. | nevermore1 | 1 |
+ * +----+-------------------------------------+------------+----------+
+ * }
+ *
+ * The position is not zero based, but 1 based index. Returns 0 if substr could not be found in
+ * str.
+ *
+ * @param str Column on which instr has to be applied
+ * @param substring Pattern to be retrieved
+ * @return A null if either of the arguments are null.
+ * @since 1.15.0
+ */
+ public static Column instr(Column str, String substring) {
+ return new Column(com.snowflake.snowpark.functions.instr(str.toScalaColumn(), substring));
+ }
+
+ /**
+ * Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in UTC, and renders
+ * that time as a timestamp in the given time zone. For example, 'GMT+1' would yield '2017-07-14
+ * 03:40:00.0'.
+ *
+ * For Example
+ *
+ *
{@code
+ * ALTER SESSION SET TIMEZONE = 'America/Los_Angeles';
+ * SELECT TO_TIMESTAMP_TZ('2024-04-05 01:02:03');
+ * +----------------------------------------+
+ * | TO_TIMESTAMP_TZ('2024-04-05 01:02:03') |
+ * |----------------------------------------|
+ * | 2024-04-05 01:02:03.000 -0700 |
+ * +----------------------------------------+
+ * }
+ *
+ * @param ts A date, timestamp or string. If a string, the data must be in a format that can be
+ * cast to a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` A string detailing
+ * the time zone ID that the input should be adjusted to. It should be in the format of either
+ * region-based zone IDs or zone offsets. Region IDs must have the form 'area/city', such as
+ * 'America/Los_Angeles'. Zone offsets must be in the format '(+|-)HH:mm', for example
+ * '-08:00' or '+01:00'. Also 'UTC' and 'Z' are supported as aliases of '+00:00'. Other short
+ * names are not recommended to use because they can be ambiguous.
+ * @return A timestamp, or null if `ts` was a string that could not be cast to a timestamp or `tz`
+ * was an invalid value
+ * @since 1.15.0
+ */
+ public static Column from_utc_timestamp(Column ts) {
+ return new Column(com.snowflake.snowpark.functions.from_utc_timestamp(ts.toScalaColumn()));
+ }
+
+ /**
+ * Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in the given time zone,
+ * and renders that time as a timestamp in UTC. For example, 'GMT+1' would yield '2017-07-14
+ * 01:40:00.0'.
+ *
+ * @param ts A date, timestamp or string. If a string, the data must be in a format that can be
+ * cast to a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` A string detailing
+ * the time zone ID that the input should be adjusted to. It should be in the format of either
+ * region-based zone IDs or zone offsets. Region IDs must have the form 'area/city', such as
+ * 'America/Los_Angeles'. Zone offsets must be in the format '(+|-)HH:mm', for example
+ * '-08:00' or '+01:00'. Also 'UTC' and 'Z' are supported as aliases of '+00:00'. Other short
+ * names are not recommended to use because they can be ambiguous.
+ * @return A timestamp, or null if `ts` was a string that could not be cast to a timestamp or `tz`
+ * was an invalid value
+ * @since 1.15.0
+ */
+ public static Column to_utc_timestamp(Column ts) {
+ return new Column(com.snowflake.snowpark.functions.to_utc_timestamp(ts.toScalaColumn()));
+ }
+
+ /**
+ * Formats numeric column x to a format like '#,###,###.##', rounded to d decimal places with
+ * HALF_EVEN round mode, and returns the result as a string column.
+ *
+ * If d is 0, the result has no decimal point or fractional part. If d is less than 0, the
+ * result will be null.
+ *
+ * @param x numeric column to be transformed
+ * @param d Amount of decimal for the number format
+ * @return Number casted to the specific string format
+ * @since 1.15.0
+ */
+ public static Column format_number(Column x, Integer d) {
+ return new Column(com.snowflake.snowpark.functions.format_number(x.toScalaColumn(), d));
+ }
/* Returns a Column expression with values sorted in descending order.
*
diff --git a/src/main/java/com/snowflake/snowpark_java/Row.java b/src/main/java/com/snowflake/snowpark_java/Row.java
index 0921a0d6..74475e2d 100644
--- a/src/main/java/com/snowflake/snowpark_java/Row.java
+++ b/src/main/java/com/snowflake/snowpark_java/Row.java
@@ -425,6 +425,64 @@ public Row getObject(int index) {
return (Row) get(index);
}
+ /**
+  * Returns the value at the specified column index and casts it to the desired type {@code T}.
+  *
+  * <p>A null value is returned as {@code null} whatever {@code clazz} is. For the numeric wrapper
+  * classes ({@code Byte}, {@code Double}, {@code Float}, {@code Integer}, {@code Long},
+  * {@code Short}) and for {@code Variant}, the value is read through the matching typed getter of
+  * this row; every other class falls back to a plain cast of {@link #get(int)}.
+  *
+  * <p>Example:
+  *
+  * <pre>{@code
+  * Row row = Row.create(1, "Alice", 95.5);
+  * row.getAs(0, Integer.class); // Returns 1 as an Integer
+  * row.getAs(1, String.class); // Returns "Alice" as a String
+  * row.getAs(2, Double.class); // Returns 95.5 as a Double
+  * }</pre>
+  *
+  * @param index the zero-based column index within the row.
+  * @param clazz the {@code Class} object representing the type {@code T}.
+  * @param <T> the expected type of the value at the specified column index.
+  * @return the value at the specified column index cast to type {@code T}.
+  * @throws ClassCastException if the value at the given index cannot be cast to type {@code T}.
+  * @throws ArrayIndexOutOfBoundsException if the column index is out of bounds.
+  * @since 1.15.0
+  */
+ @SuppressWarnings("unchecked")
+ public <T> T getAs(int index, Class<T> clazz)
+     throws ClassCastException, ArrayIndexOutOfBoundsException {
+   // Nulls bypass the typed getters and are handed back as-is.
+   if (isNullAt(index)) {
+     return (T) get(index);
+   }
+
+   // The typed getters return primitives; the (Object) cast boxes them before the cast to T.
+   if (clazz == Byte.class) {
+     return (T) (Object) getByte(index);
+   }
+
+   if (clazz == Double.class) {
+     return (T) (Object) getDouble(index);
+   }
+
+   if (clazz == Float.class) {
+     return (T) (Object) getFloat(index);
+   }
+
+   if (clazz == Integer.class) {
+     return (T) (Object) getInt(index);
+   }
+
+   if (clazz == Long.class) {
+     return (T) (Object) getLong(index);
+   }
+
+   if (clazz == Short.class) {
+     return (T) (Object) getShort(index);
+   }
+
+   if (clazz == Variant.class) {
+     return (T) getVariant(index);
+   }
+
+   // Any other requested type: return the stored object under an unchecked cast.
+   return (T) get(index);
+ }
+
/**
* Generates a string value to represent the content of this row.
*
diff --git a/src/main/scala/com/snowflake/snowpark/Row.scala b/src/main/scala/com/snowflake/snowpark/Row.scala
index ac929f11..40ec4ffa 100644
--- a/src/main/scala/com/snowflake/snowpark/Row.scala
+++ b/src/main/scala/com/snowflake/snowpark/Row.scala
@@ -331,6 +331,43 @@ class Row protected (values: Array[Any]) extends Serializable {
getAs[Map[T, U]](index)
}
+ /** Returns the value at the specified column index and casts it to the desired type `T`.
+ *
+ * Example:
+ * {{{
+ * val row = Row(1, "Alice", 95.5)
+ * row.getAs[Int](0) // Returns 1 as an Int
+ * row.getAs[String](1) // Returns "Alice" as a String
+ * row.getAs[Double](2) // Returns 95.5 as a Double
+ * }}}
+ *
+ * @param index
+ * the zero-based column index within the row.
+ * @tparam T
+ * the expected type of the value at the specified column index.
+ * @return
+ * the value at the specified column index cast to type `T`.
+ * @throws ClassCastException
+ * if the value at the given index cannot be cast to type `T`.
+ * @throws ArrayIndexOutOfBoundsException
+ * if the column index is out of bounds.
+ * @group getter
+ * @since 1.15.0
+ */
+ def getAs[T](index: Int)(implicit classTag: ClassTag[T]): T = {
+   if (isNullAt(index)) {
+     // Nulls are handed back untouched, whatever type was requested.
+     get(index).asInstanceOf[T]
+   } else {
+     val requested = classTag.runtimeClass
+     // Numeric primitives and Variant go through their dedicated getters;
+     // every other type is a plain cast of the stored value.
+     val value: Any =
+       if (requested == classOf[Byte]) getByte(index)
+       else if (requested == classOf[Double]) getDouble(index)
+       else if (requested == classOf[Float]) getFloat(index)
+       else if (requested == classOf[Int]) getInt(index)
+       else if (requested == classOf[Long]) getLong(index)
+       else if (requested == classOf[Short]) getShort(index)
+       else if (requested == classOf[Variant]) getVariant(index)
+       else get(index)
+     value.asInstanceOf[T]
+   }
+ }
+
protected def convertValueToString(value: Any): String =
value match {
case null => "null"
@@ -362,10 +399,7 @@ class Row protected (values: Array[Any]) extends Serializable {
.map(convertValueToString)
.mkString("Row[", ",", "]")
- private def getAs[T](index: Int): T = get(index).asInstanceOf[T]
-
- private def getAnyValAs[T <: AnyVal](index: Int): T =
+ // Returns the value at `index` as `T`, throwing NullPointerException when the value is null;
+ // non-null values are read through getAs using the ClassTag supplied by the caller.
+ private def getAnyValAs[T <: AnyVal](index: Int)(implicit classTag: ClassTag[T]): T =
if (isNullAt(index)) throw new NullPointerException(s"Value at index $index is null")
- else getAs[T](index)
-
+ else getAs[T](index)(classTag)
}
diff --git a/src/main/scala/com/snowflake/snowpark/functions.scala b/src/main/scala/com/snowflake/snowpark/functions.scala
index e3dad8ec..69970c5f 100644
--- a/src/main/scala/com/snowflake/snowpark/functions.scala
+++ b/src/main/scala/com/snowflake/snowpark/functions.scala
@@ -800,6 +800,206 @@ object functions {
*/
def pow(l: Column, r: Column): Column = builtin("pow")(l, r)
+ /** Returns a number (l) raised to the specified power (r).
+ *
+ * Example:
+ * {{{
+ * val df = session.sql(
+ * "select * from (values (0.1, 2), (2, 3), (2, 0.5), (2, -1)) as T(base, exponent)")
+ * df.select(col("base"), col("exponent"), pow(col("base"), "exponent").as("result")).show()
+ *
+ * ----------------------------------------------
+ * |"BASE" |"EXPONENT" |"RESULT" |
+ * ----------------------------------------------
+ * |0.1 |2.0 |0.010000000000000002 |
+ * |2.0 |3.0 |8.0 |
+ * |2.0 |0.5 |1.4142135623730951 |
+ * |2.0 |-1.0 |0.5 |
+ * ----------------------------------------------
+ * }}}
+ *
+ * @param l
+ * The numeric column representing the base.
+ * @param r
+ * The name of the numeric column representing the exponent.
+ * @return
+ * A column containing the result of raising `l` to the power of `r`.
+ * @group num_func
+ * @since 1.15.0
+ */
+ def pow(l: Column, r: String): Column = {
+   // Resolve the exponent by column name, then use the Column/Column overload.
+   val exponent = col(r)
+   pow(l, exponent)
+ }
+
+ /** Returns a number (l) raised to the specified power (r).
+ *
+ * Example:
+ * {{{
+ * val df = session.sql(
+ * "select * from (values (0.1, 2), (2, 3), (2, 0.5), (2, -1)) as T(base, exponent)")
+ * df.select(col("base"), col("exponent"), pow("base", col("exponent")).as("result")).show()
+ *
+ * ----------------------------------------------
+ * |"BASE" |"EXPONENT" |"RESULT" |
+ * ----------------------------------------------
+ * |0.1 |2.0 |0.010000000000000002 |
+ * |2.0 |3.0 |8.0 |
+ * |2.0 |0.5 |1.4142135623730951 |
+ * |2.0 |-1.0 |0.5 |
+ * ----------------------------------------------
+ * }}}
+ *
+ * @param l
+ * The name of the numeric column representing the base.
+ * @param r
+ * The numeric column representing the exponent.
+ * @return
+ * A column containing the result of raising `l` to the power of `r`.
+ * @group num_func
+ * @since 1.15.0
+ */
+ def pow(l: String, r: Column): Column = {
+   // Resolve the base by column name, then use the Column/Column overload.
+   val base = col(l)
+   pow(base, r)
+ }
+
+ /** Returns a number (l) raised to the specified power (r).
+ *
+ * Example:
+ * {{{
+ * val df = session.sql(
+ * "select * from (values (0.1, 2), (2, 3), (2, 0.5), (2, -1)) as T(base, exponent)")
+ * df.select(col("base"), col("exponent"), pow("base", "exponent").as("result")).show()
+ *
+ * ----------------------------------------------
+ * |"BASE" |"EXPONENT" |"RESULT" |
+ * ----------------------------------------------
+ * |0.1 |2.0 |0.010000000000000002 |
+ * |2.0 |3.0 |8.0 |
+ * |2.0 |0.5 |1.4142135623730951 |
+ * |2.0 |-1.0 |0.5 |
+ * ----------------------------------------------
+ * }}}
+ *
+ * @param l
+ * The name of the numeric column representing the base.
+ * @param r
+ * The name of the numeric column representing the exponent.
+ * @return
+ * A column containing the result of raising `l` to the power of `r`.
+ * @group num_func
+ * @since 1.15.0
+ */
+ def pow(l: String, r: String): Column = {
+   // Resolve both operands by column name, then use the Column/Column overload.
+   val base = col(l)
+   val exponent = col(r)
+   pow(base, exponent)
+ }
+
+ /** Returns a number (l) raised to the specified power (r).
+ *
+ * Example:
+ * {{{
+ * val df = session.sql("select * from (values (0.5), (2), (2.5), (4)) as T(base)")
+ * df.select(col("base"), lit(2.0).as("exponent"), pow(col("base"), 2.0).as("result")).show()
+ *
+ * ----------------------------------
+ * |"BASE" |"EXPONENT" |"RESULT" |
+ * ----------------------------------
+ * |0.5 |2.0 |0.25 |
+ * |2.0 |2.0 |4.0 |
+ * |2.5 |2.0 |6.25 |
+ * |4.0 |2.0 |16.0 |
+ * ----------------------------------
+ * }}}
+ *
+ * @param l
+ * The numeric column representing the base.
+ * @param r
+ * The value of the exponent.
+ * @return
+ * A column containing the result of raising `l` to the power of `r`.
+ * @group num_func
+ * @since 1.15.0
+ */
+ def pow(l: Column, r: Double): Column = {
+   // Turn the constant exponent into a literal column, then use the Column/Column overload.
+   val exponent = lit(r)
+   pow(l, exponent)
+ }
+
+ /** Returns a number (l) raised to the specified power (r).
+ *
+ * Example:
+ * {{{
+ * val df = session.sql("select * from (values (0.5), (2), (2.5), (4)) as T(base)")
+ * df.select(col("base"), lit(2.0).as("exponent"), pow("base", 2.0).as("result")).show()
+ *
+ * ----------------------------------
+ * |"BASE" |"EXPONENT" |"RESULT" |
+ * ----------------------------------
+ * |0.5 |2.0 |0.25 |
+ * |2.0 |2.0 |4.0 |
+ * |2.5 |2.0 |6.25 |
+ * |4.0 |2.0 |16.0 |
+ * ----------------------------------
+ * }}}
+ *
+ * @param l
+ * The name of the numeric column representing the base.
+ * @param r
+ * The value of the exponent.
+ * @return
+ * A column containing the result of raising `l` to the power of `r`.
+ * @group num_func
+ * @since 1.15.0
+ */
+ def pow(l: String, r: Double): Column = {
+   // Resolve the base by column name; the Column/Double overload handles the constant exponent.
+   val base = col(l)
+   pow(base, r)
+ }
+
+ /** Returns a number (l) raised to the specified power (r).
+ *
+ * Example:
+ * {{{
+ * val df = session.sql("select * from (values (0.5), (2), (2.5), (4)) as T(exponent)")
+ * df.select(lit(2.0).as("base"), col("exponent"), pow(2.0, col("exponent")).as("result"))
+ * .show()
+ *
+ * --------------------------------------------
+ * |"BASE" |"EXPONENT" |"RESULT" |
+ * --------------------------------------------
+ * |2.0 |0.5 |1.4142135623730951 |
+ * |2.0 |2.0 |4.0 |
+ * |2.0 |2.5 |5.656854249492381 |
+ * |2.0 |4.0 |16.0 |
+ * --------------------------------------------
+ * }}}
+ *
+ * @param l
+ * The value of the base.
+ * @param r
+ * The numeric column representing the exponent.
+ * @return
+ * A column containing the result of raising `l` to the power of `r`.
+ * @group num_func
+ * @since 1.15.0
+ */
+ def pow(l: Double, r: Column): Column = {
+   // Turn the constant base into a literal column, then use the Column/Column overload.
+   val base = lit(l)
+   pow(base, r)
+ }
+
+ /** Returns a number (l) raised to the specified power (r).
+ *
+ * Example:
+ * {{{
+ * val df = session.sql("select * from (values (0.5), (2), (2.5), (4)) as T(exponent)")
+ * df.select(lit(2.0).as("base"), col("exponent"), pow(2.0, "exponent").as("result")).show()
+ *
+ * --------------------------------------------
+ * |"BASE" |"EXPONENT" |"RESULT" |
+ * --------------------------------------------
+ * |2.0 |0.5 |1.4142135623730951 |
+ * |2.0 |2.0 |4.0 |
+ * |2.0 |2.5 |5.656854249492381 |
+ * |2.0 |4.0 |16.0 |
+ * --------------------------------------------
+ * }}}
+ *
+ * @param l
+ * The value of the base.
+ * @param r
+ * The name of the numeric column representing the exponent.
+ * @return
+ * A column containing the result of raising `l` to the power of `r`.
+ * @group num_func
+ * @since 1.15.0
+ */
+ def pow(l: Double, r: String): Column = {
+   // Resolve the exponent by column name; the Double/Column overload handles the constant base.
+   val exponent = col(r)
+   pow(l, exponent)
+ }
+
/** Rounds the numeric values of the given column `e` to the `scale` decimal places using the half
* away from zero rounding mode.
*
@@ -3255,6 +3455,111 @@ object functions {
*/
def monotonically_increasing_id(): Column = builtin("seq8")()
+ /** Returns the number of months between the values of the columns named `start` and `end`.
+  *
+  * A whole number is returned if both inputs have the same day of month or both are the last day
+  * of their respective months. Otherwise, the difference is calculated assuming 31 days per
+  * month.
+  *
+  * Both arguments are column names, not date literals: each one is resolved with `col` before
+  * `MONTHS_BETWEEN` is applied.
+  *
+  * Example:
+  * {{{
+  * val df = session.sql(
+  *   "select * from values('2010-07-02'::Date,'2010-08-02'::Date), " +
+  *     "('2020-08-02'::Date,'2020-12-02'::Date) as t(a,b)")
+  * df.select(months_between("b", "a")).show() // 1.0 and 4.0
+  * }}}
+  *
+  * @param end
+  *   Name of the column holding the end value. If the column holds strings, the data must be in
+  *   a format that can be cast to a timestamp, such as yyyy-MM-dd or yyyy-MM-dd HH:mm:ss.SSSS
+  * @param start
+  *   Name of the column holding the start value. If the column holds strings, the data must be
+  *   in a format that can be cast to a timestamp, such as yyyy-MM-dd or yyyy-MM-dd HH:mm:ss.SSSS
+  * @return
+  *   A double, or null if either end or start were strings that could not be cast to a
+  *   timestamp. Negative if end is before start
+  * @since 1.15.0
+  */
+ def months_between(end: String, start: String): Column = {
+   val endColumn = col(end)
+   val startColumn = col(start)
+   builtin("MONTHS_BETWEEN")(endColumn, startColumn)
+ }
+
+ /** Locates the position of the first occurrence of `substring` in the string column `str`.
+  * Returns null if either of the arguments is null.
+  *
+  * The lookup is issued as `REGEXP_INSTR`, so `substring` is handed to Snowflake as a pattern.
+  *
+  * Example:
+  * {{{
+  * SELECT id,
+  *        string1,
+  *        REGEXP_SUBSTR(string1, 'nevermore\\d') AS substring,
+  *        REGEXP_INSTR( string1, 'nevermore\\d') AS position
+  * FROM demo1
+  * ORDER BY id;
+  *
+  * +----+-------------------------------------+------------+----------+
+  * | ID | STRING1                             | SUBSTRING  | POSITION |
+  * |----+-------------------------------------+------------+----------|
+  * |  1 | nevermore1, nevermore2, nevermore3. | nevermore1 |        1 |
+  * +----+-------------------------------------+------------+----------+
+  * }}}
+  *
+  * @param str
+  *   Column on which instr has to be applied
+  * @param substring
+  *   Pattern to be located
+  * @return
+  *   A column holding the position of the first match
+  * @note
+  *   The position is not zero based, but 1 based index. Returns 0 if substr could not be found in
+  *   str.
+  * @since 1.15.0
+  */
+ def instr(str: Column, substring: String): Column = {
+   val functionName = "REGEXP_INSTR"
+   builtin(functionName)(str, substring)
+ }
+
+ /** Converts `ts` to a timestamp with time zone by applying Snowflake's `TO_TIMESTAMP_TZ`.
+  *
+  * NOTE(review): earlier documentation described interpreting the input as UTC and rendering it
+  * in a caller-supplied time zone (`tz`), e.g. 'GMT+1' turning '2017-07-14 02:40:00.0' into
+  * '2017-07-14 03:40:00.0'. This method takes no time-zone argument; the example below suggests
+  * the session TIMEZONE setting governs the result - confirm.
+  *
+  * Example:
+  * {{{
+  * ALTER SESSION SET TIMEZONE = 'America/Los_Angeles';
+  * SELECT TO_TIMESTAMP_TZ('2024-04-05 01:02:03');
+  * +----------------------------------------+
+  * | TO_TIMESTAMP_TZ('2024-04-05 01:02:03') |
+  * |----------------------------------------|
+  * | 2024-04-05 01:02:03.000 -0700          |
+  * +----------------------------------------+
+  * }}}
+  *
+  * @param ts
+  *   A date, timestamp or string. If a string, the data must be in a format that can be cast to a
+  *   timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS`
+  * @return
+  *   A timestamp; per the earlier documentation, null if `ts` was a string that could not be cast
+  *   to a timestamp (TODO confirm)
+  * @since 1.15.0
+  */
+ def from_utc_timestamp(ts: Column): Column = {
+   val functionName = "TO_TIMESTAMP_TZ"
+   builtin(functionName)(ts)
+ }
+
+ /** Converts `ts` to a timestamp with time zone by applying Snowflake's `TO_TIMESTAMP_TZ`.
+  *
+  * NOTE(review): earlier documentation described interpreting the input as a time in a
+  * caller-supplied time zone (`tz`) and rendering it in UTC, e.g. 'GMT+1' turning
+  * '2017-07-14 02:40:00.0' into '2017-07-14 01:40:00.0'. This method takes no time-zone
+  * argument and issues the same SQL function as `from_utc_timestamp` - confirm the intended
+  * semantics.
+  *
+  * @param ts
+  *   A date, timestamp or string. If a string, the data must be in a format that can be cast to a
+  *   timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS`
+  * @return
+  *   A timestamp; per the earlier documentation, null if `ts` was a string that could not be cast
+  *   to a timestamp (TODO confirm)
+  * @since 1.15.0
+  */
+ def to_utc_timestamp(ts: Column): Column = {
+   val functionName = "TO_TIMESTAMP_TZ"
+   builtin(functionName)(ts)
+ }
+
+ /** Formats numeric column x with thousands separators and d decimal places, and returns the
+  * result as a string column.
+  *
+  * If d is 0, the result has no decimal point or fractional part. If d is less than 0, the
+  * result will be null.
+  *
+  * The value is rendered through `TO_VARCHAR` with a `999,999` mask followed by d fractional
+  * digits.
+  *
+  * NOTE(review): earlier documentation promised a '#,###,###.##' style result with HALF_EVEN
+  * rounding; confirm both against the behavior of the `999,999` mask.
+  *
+  * @param x
+  *   numeric column to be transformed
+  * @param d
+  *   Number of decimal places for the number format
+  * @return
+  *   Number cast to the specific string format
+  * @since 1.15.0
+  */
+ def format_number(x: Column, d: Int): Column =
+   d match {
+     // A negative scale yields a null literal instead of a formatted value.
+     case negative if negative < 0 => lit(null)
+     // Zero scale: integer mask only, no decimal point.
+     case 0 => builtin("TO_VARCHAR")(x, "999,999")
+     // Positive scale: append one '0' placeholder per requested decimal place.
+     case _ => builtin("TO_VARCHAR")(x, "999,999." + "0" * d)
+   }
/* Returns a Column expression with values sorted in descending order.
* Example:
* {{{
diff --git a/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java
index 0d2d2b30..89e9cbb5 100644
--- a/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java
+++ b/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java
@@ -593,11 +593,46 @@ public void pow() {
DataFrame df =
getSession().sql("select * from values(0.1, 0.5),(0.2, 0.6),(0.3, 0.7) as T(a,b)");
Row[] expected = {
- Row.create(0.31622776601683794),
- Row.create(0.3807307877431757),
- Row.create(0.4305116202499342)
+ Row.create(
+ 0.31622776601683794,
+ 0.31622776601683794,
+ 0.31622776601683794,
+ 0.31622776601683794,
+ 0.15848931924611134,
+ 0.15848931924611134,
+ 0.6324555320336759,
+ 0.6324555320336759),
+ Row.create(
+ 0.3807307877431757,
+ 0.3807307877431757,
+ 0.3807307877431757,
+ 0.3807307877431757,
+ 0.27594593229224296,
+ 0.27594593229224296,
+ 0.5770799623628855,
+ 0.5770799623628855),
+ Row.create(
+ 0.4305116202499342,
+ 0.4305116202499342,
+ 0.4305116202499342,
+ 0.4305116202499342,
+ 0.3816778909618176,
+ 0.3816778909618176,
+ 0.526552881733695,
+ 0.526552881733695)
};
- checkAnswer(df.select(Functions.pow(df.col("a"), df.col("b"))), expected, false);
+ checkAnswer(
+ df.select(
+ Functions.pow(df.col("a"), df.col("b")),
+ Functions.pow(df.col("a"), "b"),
+ Functions.pow("a", df.col("b")),
+ Functions.pow("a", "b"),
+ Functions.pow(df.col("a"), 0.8),
+ Functions.pow("a", 0.8),
+ Functions.pow(0.4, df.col("b")),
+ Functions.pow(0.4, "b")),
+ expected,
+ false);
}
@Test
@@ -3100,4 +3135,55 @@ public void unhex() {
Row[] expected = {Row.create("1"), Row.create("2"), Row.create("3")};
checkAnswer(df.select(Functions.unhex(Functions.col("a"))), expected, false);
}
+
+  @Test
+  public void months_between() {
+    // Column b holds the later date of each pair and is passed first; gaps of 1 and 4 expected.
+    String query =
+        "select * from values('2010-07-02'::Date,'2010-08-02'::Date), "
+            + "('2020-08-02'::Date,'2020-12-02'::Date) as t(a,b)";
+    DataFrame dates = getSession().sql(query);
+    Row[] wholeMonths = {Row.create(1.0), Row.create(4.0)};
+    checkAnswer(dates.select(Functions.months_between("b", "a")), wholeMonths, false);
+  }
+
+  @Test
+  public void instr() {
+    // The first "was" starts at 1-based position 4 of the sentence.
+    String query =
+        "select * from values('It was the best of times, it was the worst of times') as t(a)";
+    DataFrame sentence = getSession().sql(query);
+    Row[] firstMatch = {Row.create(4)};
+    checkAnswer(sentence.select(Functions.instr(sentence.col("a"), "was")), firstMatch, false);
+  }
+
+  @Test
+  public void format_number1() {
+    Row[] noDecimals = {Row.create("1"), Row.create("2"), Row.create("3")};
+    DataFrame nums = getSession().sql("select * from values(1),(2),(3) as t(a)");
+    DataFrame text = nums.select(Functions.ltrim(Functions.format_number(nums.col("a"), 0)));
+    checkAnswer(text, noDecimals, false);
+  }
+
+  @Test
+  public void format_number2() {
+    Row[] twoDecimals = {Row.create("1.00"), Row.create("2.00"), Row.create("3.00")};
+    DataFrame nums = getSession().sql("select * from values(1),(2),(3) as t(a)");
+    DataFrame text = nums.select(Functions.ltrim(Functions.format_number(nums.col("a"), 2)));
+    checkAnswer(text, twoDecimals, false);
+  }
+
+  @Test
+  public void from_utc_timestamp() {
+    Row[] unchanged = {Row.create(Timestamp.valueOf("2024-04-05 01:02:03.0"))};
+    DataFrame input = getSession().sql("select * from values('2024-04-05 01:02:03') as t(a)");
+    checkAnswer(input.select(Functions.from_utc_timestamp(input.col("a"))), unchanged, false);
+  }
+
+  @Test
+  public void to_utc_timestamp() {
+    Row[] unchanged = {Row.create(Timestamp.valueOf("2024-04-05 01:02:03.0"))};
+    DataFrame input = getSession().sql("select * from values('2024-04-05 01:02:03') as t(a)");
+    checkAnswer(input.select(Functions.to_utc_timestamp(input.col("a"))), unchanged, false);
+  }
}
diff --git a/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java
index 0570c421..f8918292 100644
--- a/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java
+++ b/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java
@@ -1,5 +1,7 @@
package com.snowflake.snowpark_test;
+import static org.junit.Assert.assertThrows;
+
import com.snowflake.snowpark_java.DataFrame;
import com.snowflake.snowpark_java.Row;
import com.snowflake.snowpark_java.types.*;
@@ -7,10 +9,13 @@
import java.sql.Date;
import java.sql.Time;
import java.sql.Timestamp;
+import java.util.ArrayList;
import java.util.Arrays;
+import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
+import java.util.TimeZone;
import org.junit.Test;
public class JavaRowSuite extends TestBase {
@@ -429,4 +434,172 @@ public void testGetRow() {
},
getSession());
}
+
+ @Test
+ public void getAs() {
+ long milliseconds = System.currentTimeMillis();
+ // One StructField per column type that getAs is exercised against below.
+ StructType schema =
+ StructType.create(
+ new StructField("c01", DataTypes.BinaryType),
+ new StructField("c02", DataTypes.BooleanType),
+ new StructField("c03", DataTypes.ByteType),
+ new StructField("c04", DataTypes.DateType),
+ new StructField("c05", DataTypes.DoubleType),
+ new StructField("c06", DataTypes.FloatType),
+ new StructField("c07", DataTypes.GeographyType),
+ new StructField("c08", DataTypes.GeometryType),
+ new StructField("c09", DataTypes.IntegerType),
+ new StructField("c10", DataTypes.LongType),
+ new StructField("c11", DataTypes.ShortType),
+ new StructField("c12", DataTypes.StringType),
+ new StructField("c13", DataTypes.TimeType),
+ new StructField("c14", DataTypes.TimestampType),
+ new StructField("c15", DataTypes.VariantType));
+ // A single row with a value in every column; numeric columns use the MIN_VALUE constants.
+ Row[] data = {
+ Row.create(
+ new byte[] {1, 2},
+ true,
+ Byte.MIN_VALUE,
+ Date.valueOf("2024-01-01"),
+ Double.MIN_VALUE,
+ Float.MIN_VALUE,
+ Geography.fromGeoJSON("POINT(30 10)"),
+ Geometry.fromGeoJSON("POINT(20 40)"),
+ Integer.MIN_VALUE,
+ Long.MIN_VALUE,
+ Short.MIN_VALUE,
+ "string",
+ Time.valueOf("16:23:04"),
+ new Timestamp(milliseconds),
+ new Variant(1))
+ };
+
+ DataFrame df = getSession().createDataFrame(data, schema);
+ Row row = df.collect()[0];
+ // Every column is read back through getAs with its matching Java class.
+ assert Arrays.equals(row.getAs(0, byte[].class), new byte[] {1, 2});
+ assert row.getAs(1, Boolean.class);
+ assert row.getAs(2, Byte.class) == Byte.MIN_VALUE;
+ assert row.getAs(3, Date.class).equals(Date.valueOf("2024-01-01"));
+ assert row.getAs(4, Double.class) == Double.MIN_VALUE;
+ assert row.getAs(5, Float.class) == Float.MIN_VALUE;
+ assert row.getAs(6, Geography.class)
+ .equals(
+ Geography.fromGeoJSON(
+ "{\n \"coordinates\": [\n 30,\n 10\n ],\n \"type\": \"Point\"\n}"));
+ assert row.getAs(7, Geometry.class)
+ .equals(
+ Geometry.fromGeoJSON(
+ "{\n \"coordinates\": [\n 2.000000000000000e+01,\n 4.000000000000000e+01\n ],\n \"type\": \"Point\"\n}"));
+ assert row.getAs(8, Integer.class) == Integer.MIN_VALUE;
+ assert row.getAs(9, Long.class) == Long.MIN_VALUE;
+ assert row.getAs(10, Short.class) == Short.MIN_VALUE;
+ assert row.getAs(11, String.class).equals("string");
+ assert row.getAs(12, Time.class).equals(Time.valueOf("16:23:04"));
+ assert row.getAs(13, Timestamp.class).equals(new Timestamp(milliseconds));
+ assert row.getAs(14, Variant.class).equals(new Variant(1));
+ // A mismatched class and a negative index are both expected to throw.
+ Row finalRow = row;
+ assertThrows(
+ ClassCastException.class,
+ () -> {
+ Boolean b = finalRow.getAs(0, Boolean.class);
+ });
+ assertThrows(ArrayIndexOutOfBoundsException.class, () -> finalRow.getAs(-1, Boolean.class));
+ // Second pass: a row of all nulls is expected to come back as null for every requested class.
+ data =
+ new Row[] {
+ Row.create(
+ null, null, null, null, null, null, null, null, null, null, null, null, null, null,
+ null)
+ };
+
+ df = getSession().createDataFrame(data, schema);
+ row = df.collect()[0];
+
+ assert row.getAs(0, byte[].class) == null;
+ assert row.getAs(1, Boolean.class) == null;
+ assert row.getAs(2, Byte.class) == null;
+ assert row.getAs(3, Date.class) == null;
+ assert row.getAs(4, Double.class) == null;
+ assert row.getAs(5, Float.class) == null;
+ assert row.getAs(6, Geography.class) == null;
+ assert row.getAs(7, Geometry.class) == null;
+ assert row.getAs(8, Integer.class) == null;
+ assert row.getAs(9, Long.class) == null;
+ assert row.getAs(10, Short.class) == null;
+ assert row.getAs(11, String.class) == null;
+ assert row.getAs(12, Time.class) == null;
+ assert row.getAs(13, Timestamp.class) == null;
+ assert row.getAs(14, Variant.class) == null;
+ }
+
+  @Test
+  public void getAsWithStructuredMap() {
+    structuredTypeTest(
+        () -> {
+          String query =
+              "SELECT "
+                  + "{'a':1,'b':2}::MAP(VARCHAR, NUMBER) as map1,"
+                  + "{'1':'a','2':'b'}::MAP(NUMBER, VARCHAR) as map2,"
+                  + "{'1':{'a':1,'b':2},'2':{'c':3}}::MAP(NUMBER, MAP(VARCHAR, NUMBER)) as map3";
+
+          DataFrame df = getSession().sql(query);
+          Row row = df.collect()[0];
+          // MAP(VARCHAR, NUMBER): values are read back as Long.
+          Map<?, ?> map1 = row.getAs(0, Map.class);
+          assert (Long) map1.get("a") == 1L;
+          assert (Long) map1.get("b") == 2L;
+          // MAP(NUMBER, VARCHAR): keys are looked up as Long.
+          Map<?, ?> map2 = row.getAs(1, Map.class);
+          assert map2.get(1L).equals("a");
+          assert map2.get(2L).equals("b");
+          // Nested MAP values are compared against plain java.util maps.
+          Map<?, ?> map3 = row.getAs(2, Map.class);
+          Map<String, Long> map3ExpectedInnerMap = new HashMap<>();
+          map3ExpectedInnerMap.put("a", 1L);
+          map3ExpectedInnerMap.put("b", 2L);
+          assert map3.get(1L).equals(map3ExpectedInnerMap);
+          assert map3.get(2L).equals(Collections.singletonMap("c", 3L));
+        },
+        getSession());
+  }
+
+  @Test
+  public void getAsWithStructuredArray() {
+    structuredTypeTest(
+        () -> {
+          TimeZone oldTimeZone = TimeZone.getDefault();
+          try {
+            TimeZone.setDefault(TimeZone.getTimeZone("US/Pacific"));
+
+            String query =
+                "SELECT "
+                    + "[1,2,3]::ARRAY(NUMBER) AS arr1,"
+                    + "['a','b']::ARRAY(VARCHAR) AS arr2,"
+                    + "[parse_json(31000000)::timestamp_ntz]::ARRAY(TIMESTAMP_NTZ) AS arr3,"
+                    + "[[1,2]]::ARRAY(ARRAY) AS arr4";
+
+            DataFrame df = getSession().sql(query);
+            Row row = df.collect()[0];
+
+            ArrayList<?> array1 = row.getAs(0, ArrayList.class);
+            assert array1.equals(Arrays.asList(1L, 2L, 3L));
+
+            ArrayList<?> array2 = row.getAs(1, ArrayList.class);
+            assert array2.equals(Arrays.asList("a", "b"));
+            // Compared against a java.sql.Timestamp built from epoch millis under the pinned zone.
+            ArrayList<?> array3 = row.getAs(2, ArrayList.class);
+            assert array3.equals(Collections.singletonList(new Timestamp(31000000000L)));
+            // The untyped inner ARRAY is compared against its JSON text.
+            ArrayList<?> array4 = row.getAs(3, ArrayList.class);
+            assert array4.equals(Collections.singletonList("[\n 1,\n 2\n]"));
+          } finally {
+            TimeZone.setDefault(oldTimeZone);
+          }
+        },
+        getSession());
+  }
}
diff --git a/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala b/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala
index 546164fc..3ac7996b 100644
--- a/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala
@@ -353,10 +353,25 @@ trait FunctionSuite extends TestData {
test("pow") {
checkAnswer(
- double2.select(pow(col("A"), col("B"))),
- Seq(Row(0.31622776601683794), Row(0.3807307877431757), Row(0.4305116202499342)),
+ double2.select(
+ pow(col("A"), col("B")),
+ pow(col("A"), "B"),
+ pow("A", col("B")),
+ pow("A", "B"),
+ pow(col("A"), 0.8),
+ pow("A", 0.8),
+ pow(0.4, col("B")),
+ pow(0.4, "B")),
+ Seq(
+ Row(0.31622776601683794, 0.31622776601683794, 0.31622776601683794, 0.31622776601683794,
+ 0.15848931924611134, 0.15848931924611134, 0.6324555320336759, 0.6324555320336759),
+ Row(0.3807307877431757, 0.3807307877431757, 0.3807307877431757, 0.3807307877431757,
+ 0.27594593229224296, 0.27594593229224296, 0.5770799623628855, 0.5770799623628855),
+ Row(0.4305116202499342, 0.4305116202499342, 0.4305116202499342, 0.4305116202499342,
+ 0.3816778909618176, 0.3816778909618176, 0.526552881733695, 0.526552881733695)),
sort = false)
}
+
test("shiftleft shiftright") {
checkAnswer(
integer1.select(bitshiftleft(col("A"), lit(1)), bitshiftright(col("A"), lit(1))),
@@ -2455,6 +2470,58 @@ trait FunctionSuite extends TestData {
Seq(Row("1"), Row("2"), Row("3")),
sort = false)
}
+  test("months_between") {
+    // Call the public months_between API; a local builtin("MONTHS_BETWEEN") alias would bypass it.
+    val input = Seq(
+      (Date.valueOf("2010-08-02"), Date.valueOf("2010-07-02")),
+      (Date.valueOf("2020-12-02"), Date.valueOf("2020-08-02")))
+      .toDF("a", "b")
+    checkAnswer(
+      input.select(months_between("a", "b")),
+      Seq(Row(1.0), Row(4.0)),
+      sort = false)
+  }
+
+  test("instr") {
+    val sentence = Seq("It was the best of times, it was the worst of times").toDF("a")
+    checkAnswer(sentence.select(instr(col("a"), "was")), Seq(Row(4)), sort = false) // 1-based
+  }
+
+  test("format_number1") {
+    // Zero decimal places: after ltrim only the integer digits remain.
+    val formatted = number3.select(ltrim(format_number(col("a"), 0)))
+    val expected = Seq(Row("1"), Row("2"), Row("3"))
+
+    checkAnswer(formatted, expected, sort = false)
+  }
+
+  test("format_number2") {
+    // Two decimal places: each integer gains a ".00" suffix.
+    val formatted = number3.select(ltrim(format_number(col("a"), 2)))
+    val expected = Seq(Row("1.00"), Row("2.00"), Row("3.00"))
+
+    checkAnswer(formatted, expected, sort = false)
+  }
+
+  test("format_number3") {
+    // A negative precision is expected to produce NULL for every row.
+    val formatted = number3.select(ltrim(format_number(col("a"), -1)))
+    val expected = Seq(Row(null), Row(null), Row(null))
+
+    checkAnswer(formatted, expected, sort = false)
+  }
+
+  test("from_utc_timestamp") {
+    val input = Seq("2024-04-05 01:02:03").toDF("a")
+    val unchanged = Seq(Timestamp.valueOf("2024-04-05 01:02:03.0")).toDF("a")
+    checkAnswer(input.select(from_utc_timestamp(col("a"))), unchanged, sort = false)
+  }
+
+  test("to_utc_timestamp") {
+    val input = Seq("2024-04-05 01:02:03").toDF("a")
+    val unchanged = Seq(Timestamp.valueOf("2024-04-05 01:02:03.0")).toDF("a")
+    checkAnswer(input.select(to_utc_timestamp(col("a"))), unchanged, sort = false)
+  }
}
diff --git a/src/test/scala/com/snowflake/snowpark_test/RowSuite.scala b/src/test/scala/com/snowflake/snowpark_test/RowSuite.scala
index 54aba687..8de21e92 100644
--- a/src/test/scala/com/snowflake/snowpark_test/RowSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark_test/RowSuite.scala
@@ -1,11 +1,12 @@
package com.snowflake.snowpark_test
-import com.snowflake.snowpark.types.{Geography, Geometry, Variant}
+import com.snowflake.snowpark.types._
import com.snowflake.snowpark.{Row, SNTestBase, SnowparkClientException}
-import java.sql.{Date, Timestamp}
+import java.sql.{Date, Time, Timestamp}
import java.time.{Instant, LocalDate}
import java.util
+import java.util.TimeZone
class RowSuite extends SNTestBase {
@@ -233,6 +234,161 @@ class RowSuite extends SNTestBase {
assert(err.message.matches(msg))
}
+ test("getAs") {
+ val milliseconds = System.currentTimeMillis()
+ // One StructField per column type that getAs is exercised against below.
+ val schema = StructType(
+ Seq(
+ StructField("c01", BinaryType),
+ StructField("c02", BooleanType),
+ StructField("c03", ByteType),
+ StructField("c04", DateType),
+ StructField("c05", DoubleType),
+ StructField("c06", FloatType),
+ StructField("c07", GeographyType),
+ StructField("c08", GeometryType),
+ StructField("c09", IntegerType),
+ StructField("c10", LongType),
+ StructField("c11", ShortType),
+ StructField("c12", StringType),
+ StructField("c13", TimeType),
+ StructField("c14", TimestampType),
+ StructField("c15", VariantType)))
+ // A single row with a value in every column.
+ var data = Seq(
+ Row(
+ Array[Byte](1, 9),
+ true,
+ Byte.MinValue,
+ Date.valueOf("2024-01-01"),
+ Double.MinValue,
+ Float.MinPositiveValue,
+ Geography.fromGeoJSON("POINT(30 10)"),
+ Geometry.fromGeoJSON("POINT(20 40)"),
+ Int.MinValue,
+ Long.MinValue,
+ Short.MinValue,
+ "string",
+ Time.valueOf("16:23:04"),
+ new Timestamp(milliseconds),
+ new Variant(1)))
+
+ var df = session.createDataFrame(data, schema)
+ var row = df.collect()(0)
+ // Every column is read back through getAs with its matching Scala/Java type.
+ assert(row.getAs[Array[Byte]](0) sameElements Array[Byte](1, 9))
+ assert(row.getAs[Boolean](1))
+ assert(row.getAs[Byte](2) == Byte.MinValue)
+ assert(row.getAs[Date](3) == Date.valueOf("2024-01-01"))
+ assert(row.getAs[Double](4) == Double.MinValue)
+ assert(row.getAs[Float](5) == Float.MinPositiveValue)
+ assert(row.getAs[Geography](6) == Geography.fromGeoJSON("""{
+ | "coordinates": [
+ | 30,
+ | 10
+ | ],
+ | "type": "Point"
+ |}""".stripMargin))
+ assert(row.getAs[Geometry](7) == Geometry.fromGeoJSON("""{
+ | "coordinates": [
+ | 2.000000000000000e+01,
+ | 4.000000000000000e+01
+ | ],
+ | "type": "Point"
+ |}""".stripMargin))
+ assert(row.getAs[Int](8) == Int.MinValue)
+ assert(row.getAs[Long](9) == Long.MinValue)
+ assert(row.getAs[Short](10) == Short.MinValue)
+ assert(row.getAs[String](11) == "string")
+ assert(row.getAs[Time](12) == Time.valueOf("16:23:04"))
+ assert(row.getAs[Timestamp](13) == new Timestamp(milliseconds))
+ assert(row.getAs[Variant](14) == new Variant(1))
+ assertThrows[ClassCastException](row.getAs[Boolean](0))
+ assertThrows[ArrayIndexOutOfBoundsException](row.getAs[Boolean](-1))
+ // Second pass: a row in which every column is null.
+ data = Seq(
+ Row(null, null, null, null, null, null, null, null, null, null, null, null, null, null, null))
+
+ df = session.createDataFrame(data, schema)
+ row = df.collect()(0)
+ // Null cells read as Boolean/numeric types are asserted to be false/0; the rest to be null.
+ assert(row.getAs[Array[Byte]](0) == null)
+ assert(!row.getAs[Boolean](1))
+ assert(row.getAs[Byte](2) == 0)
+ assert(row.getAs[Date](3) == null)
+ assert(row.getAs[Double](4) == 0)
+ assert(row.getAs[Float](5) == 0)
+ assert(row.getAs[Geography](6) == null)
+ assert(row.getAs[Geometry](7) == null)
+ assert(row.getAs[Int](8) == 0)
+ assert(row.getAs[Long](9) == 0)
+ assert(row.getAs[Short](10) == 0)
+ assert(row.getAs[String](11) == null)
+ assert(row.getAs[Time](12) == null)
+ assert(row.getAs[Timestamp](13) == null)
+ assert(row.getAs[Variant](14) == null)
+ }
+
+  test("getAs with structured map") {
+    structuredTypeTest {
+      val structuredMaps =
+        """SELECT
+          | {'a':1,'b':2}::MAP(VARCHAR, NUMBER) as map1,
+          | {'1':'a','2':'b'}::MAP(NUMBER, VARCHAR) as map2,
+          | {'1':{'a':1,'b':2},'2':{'c':3}}::MAP(NUMBER, MAP(VARCHAR, NUMBER)) as map3
+          |""".stripMargin
+
+      val firstRow = session.sql(structuredMaps).collect()(0)
+
+      // Each structured MAP column is read through getAs as a Scala Map and probed by key.
+      val byName = firstRow.getAs[Map[String, Long]](0)
+      assert(byName("a") == 1L)
+      assert(byName("b") == 2L)
+
+      val byNumber = firstRow.getAs[Map[Long, String]](1)
+      assert(byNumber(1) == "a")
+      assert(byNumber(2) == "b")
+
+      val nested = firstRow.getAs[Map[Long, Map[String, Long]]](2)
+      assert(nested(1) == Map("a" -> 1, "b" -> 2))
+      assert(nested(2) == Map("c" -> 3))
+    }
+  }
+
+  test("getAs with structured array") {
+    structuredTypeTest {
+      val savedZone = TimeZone.getDefault
+      try {
+        TimeZone.setDefault(TimeZone.getTimeZone("US/Pacific"))
+
+        val arraysSql =
+          """SELECT
+            | [1,2,3]::ARRAY(NUMBER) AS arr1,
+            | ['a','b']::ARRAY(VARCHAR) AS arr2,
+            | [parse_json(31000000)::timestamp_ntz]::ARRAY(TIMESTAMP_NTZ) AS arr3,
+            | [[1,2]]::ARRAY(ARRAY) AS arr4
+            |""".stripMargin
+
+        val firstRow = session.sql(arraysSql).collect()(0)
+
+        // Each structured ARRAY column is read as Array[Object] and compared element-wise.
+        val numbers = firstRow.getAs[Array[Object]](0)
+        assert(numbers sameElements Array(1, 2, 3))
+
+        val strings = firstRow.getAs[Array[Object]](1)
+        assert(strings sameElements Array("a", "b"))
+
+        val timestamps = firstRow.getAs[Array[Object]](2)
+        assert(timestamps sameElements Array(new Timestamp(31000000000L)))
+
+        val nested = firstRow.getAs[Array[Object]](3)
+        assert(nested sameElements Array("[\n 1,\n 2\n]"))
+      } finally {
+        TimeZone.setDefault(savedZone)
+      }
+    }
+  }
+
test("hashCode") {
val row1 = Row(1, 2, 3)
val row2 = Row("str", null, 3)