diff --git a/src/main/java/com/snowflake/snowpark_java/Functions.java b/src/main/java/com/snowflake/snowpark_java/Functions.java index b75f081e..ce790653 100644 --- a/src/main/java/com/snowflake/snowpark_java/Functions.java +++ b/src/main/java/com/snowflake/snowpark_java/Functions.java @@ -1028,6 +1028,202 @@ public static Column pow(Column l, Column r) { return new Column(com.snowflake.snowpark.functions.pow(l.toScalaColumn(), r.toScalaColumn())); } + /** + * Returns a number (l) raised to the specified power (r). + * + *

Example: + * + *

{@code
+   * DataFrame df = session.sql("select * from (values (0.1, 2), (2, 3), (2, 0.5), (2, -1)) as T(base, exponent)");
+   * df.select(col("base"), col("exponent"), pow(col("base"), "exponent").as("result")).show();
+   *
+   * ----------------------------------------------
+   * |"BASE"  |"EXPONENT"  |"RESULT"              |
+   * ----------------------------------------------
+   * |0.1     |2.0         |0.010000000000000002  |
+   * |2.0     |3.0         |8.0                   |
+   * |2.0     |0.5         |1.4142135623730951    |
+   * |2.0     |-1.0        |0.5                   |
+   * ----------------------------------------------
+   * }
+ * + * @param l The numeric column representing the base. + * @param r The name of the numeric column representing the exponent. + * @return A column containing the result of raising {@code l} to the power of {@code r}. + * @since 1.15.0 + */ + public static Column pow(Column l, String r) { + return new Column(com.snowflake.snowpark.functions.pow(l.toScalaColumn(), r)); + } + + /** + * Returns a number (l) raised to the specified power (r). + * + *

Example: + * + *

{@code
+   * DataFrame df = session.sql("select * from (values (0.1, 2), (2, 3), (2, 0.5), (2, -1)) as T(base, exponent)");
+   * df.select(col("base"), col("exponent"), pow("base", col("exponent")).as("result")).show();
+   *
+   * ----------------------------------------------
+   * |"BASE"  |"EXPONENT"  |"RESULT"              |
+   * ----------------------------------------------
+   * |0.1     |2.0         |0.010000000000000002  |
+   * |2.0     |3.0         |8.0                   |
+   * |2.0     |0.5         |1.4142135623730951    |
+   * |2.0     |-1.0        |0.5                   |
+   * ----------------------------------------------
+   * }
+ * + * @param l The name of the numeric column representing the base. + * @param r The numeric column representing the exponent. + * @return A column containing the result of raising {@code l} to the power of {@code r}. + * @since 1.15.0 + */ + public static Column pow(String l, Column r) { + return new Column(com.snowflake.snowpark.functions.pow(l, r.toScalaColumn())); + } + + /** + * Returns a number (l) raised to the specified power (r). + * + *

Example: + * + *

{@code
+   * DataFrame df = session.sql("select * from (values (0.1, 2), (2, 3), (2, 0.5), (2, -1)) as T(base, exponent)");
+   * df.select(col("base"), col("exponent"), pow("base", "exponent").as("result")).show();
+   *
+   * ----------------------------------------------
+   * |"BASE"  |"EXPONENT"  |"RESULT"              |
+   * ----------------------------------------------
+   * |0.1     |2.0         |0.010000000000000002  |
+   * |2.0     |3.0         |8.0                   |
+   * |2.0     |0.5         |1.4142135623730951    |
+   * |2.0     |-1.0        |0.5                   |
+   * ----------------------------------------------
+   * }
+ * + * @param l The name of the numeric column representing the base. + * @param r The name of the numeric column representing the exponent. + * @return A column containing the result of raising {@code l} to the power of {@code r}. + * @since 1.15.0 + */ + public static Column pow(String l, String r) { + return new Column(com.snowflake.snowpark.functions.pow(l, r)); + } + + /** + * Returns a number (l) raised to the specified power (r). + * + *

Example: + * + *

{@code
+   * DataFrame df = session.sql("select * from (values (0.5), (2), (2.5), (4)) as T(base)");
+   * df.select(col("base"), lit(2.0).as("exponent"), pow(col("base"), 2.0).as("result")).show();
+   *
+   * ----------------------------------
+   * |"BASE"  |"EXPONENT"  |"RESULT"  |
+   * ----------------------------------
+   * |0.5     |2.0         |0.25      |
+   * |2.0     |2.0         |4.0       |
+   * |2.5     |2.0         |6.25      |
+   * |4.0     |2.0         |16.0      |
+   * ----------------------------------
+   * }
+ * + * @param l The numeric column representing the base. + * @param r The value of the exponent. + * @return A column containing the result of raising {@code l} to the power of {@code r}. + * @since 1.15.0 + */ + public static Column pow(Column l, Double r) { + return new Column(com.snowflake.snowpark.functions.pow(l.toScalaColumn(), r)); + } + + /** + * Returns a number (l) raised to the specified power (r). + * + *

Example: + * + *

{@code
+   * DataFrame df = session.sql("select * from (values (0.5), (2), (2.5), (4)) as T(base)");
+   * df.select(col("base"), lit(2.0).as("exponent"), pow("base", 2.0).as("result")).show();
+   *
+   * ----------------------------------
+   * |"BASE"  |"EXPONENT"  |"RESULT"  |
+   * ----------------------------------
+   * |0.5     |2.0         |0.25      |
+   * |2.0     |2.0         |4.0       |
+   * |2.5     |2.0         |6.25      |
+   * |4.0     |2.0         |16.0      |
+   * ----------------------------------
+   * }
+ * + * @param l The name of the numeric column representing the base. + * @param r The value of the exponent. + * @return A column containing the result of raising {@code l} to the power of {@code r}. + * @since 1.15.0 + */ + public static Column pow(String l, Double r) { + return new Column(com.snowflake.snowpark.functions.pow(l, r)); + } + + /** + * Returns a number (l) raised to the specified power (r). + * + *

Example: + * + *

{@code
+   * DataFrame df = session.sql("select * from (values (0.5), (2), (2.5), (4)) as T(exponent)");
+   * df.select(lit(2.0).as("base"), col("exponent"), pow(2.0, col("exponent")).as("result")).show();
+   *
+   * --------------------------------------------
+   * |"BASE"  |"EXPONENT"  |"RESULT"            |
+   * --------------------------------------------
+   * |2.0     |0.5         |1.4142135623730951  |
+   * |2.0     |2.0         |4.0                 |
+   * |2.0     |2.5         |5.656854249492381   |
+   * |2.0     |4.0         |16.0                |
+   * --------------------------------------------
+   * }
+ * + * @param l The value of the base. + * @param r The numeric column representing the exponent. + * @return A column containing the result of raising {@code l} to the power of {@code r}. + * @since 1.15.0 + */ + public static Column pow(Double l, Column r) { + return new Column(com.snowflake.snowpark.functions.pow(l, r.toScalaColumn())); + } + + /** + * Returns a number (l) raised to the specified power (r). + * + *

Example: + * + *

{@code
+   * DataFrame df = session.sql("select * from (values (0.5), (2), (2.5), (4)) as T(exponent)");
+   * df.select(lit(2.0).as("base"), col("exponent"), pow(2.0, "exponent").as("result")).show();
+   *
+   * --------------------------------------------
+   * |"BASE"  |"EXPONENT"  |"RESULT"            |
+   * --------------------------------------------
+   * |2.0     |0.5         |1.4142135623730951  |
+   * |2.0     |2.0         |4.0                 |
+   * |2.0     |2.5         |5.656854249492381   |
+   * |2.0     |4.0         |16.0                |
+   * --------------------------------------------
+   * }
+ * + * @param l The value of the base. + * @param r The name of the numeric column representing the exponent. + * @return A column containing the result of raising {@code l} to the power of {@code r}. + * @since 1.15.0 + */ + public static Column pow(Double l, String r) { + return new Column(com.snowflake.snowpark.functions.pow(l, r)); + } + /** * Rounds the numeric values of the given column {@code e} to the {@code scale} decimal places * using the half away from zero rounding mode. @@ -4308,11 +4504,142 @@ public static Column from_unixtime(Column ut, String f) { * [Row(SEQ8(0)=0),Row(SEQ8(0)=1), Row(SEQ8(0)=2)] * } * + * @return A sequence of monotonically increasing integers, with wrap-around * which happens after + * largest representable integer of integer width 8 byte. * @since 1.15.0 */ public static Column monotonically_increasing_id() { return new Column(com.snowflake.snowpark.functions.monotonically_increasing_id()); } + /** + * Returns number of months between dates `start` and `end`. + * + *

A whole number is returned if both inputs have the same day of month or both are the last + * day of their respective months. Otherwise, the difference is calculated assuming 31 days per + * month. + * + *

For example: + * + *

{@code
+   * {{{
+   * months_between("2017-11-14", "2017-07-14")  // returns 4.0
+   * months_between("2017-01-01", "2017-01-10")  // returns 0.29032258
+   * months_between("2017-06-01", "2017-06-16 12:00:00")  // returns -0.5
+   * }}}
+   * }
+ * + * @param end A date, timestamp or string. If a string, the data must be in a format that can be + * cast to a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @param start A date, timestamp or string. If a string, the data must be in a format that can + * cast to a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` + * @return A double, or null if either `end` or `start` were strings that could not be cast to a + * timestamp. Negative if `end` is before `start` + * @since 1.15.0 + */ + public static Column months_between(String end, String start) { + return new Column(functions.months_between(end, start)); + } + + /** + * Locate the position of the first occurrence of substr column in the given string. Returns null + * if either of the arguments are null. + * + *

Example + * + *

{@code
+   * SELECT id,
+   *        string1,
+   *        REGEXP_SUBSTR(string1, 'nevermore\\d') AS substring,
+   *        REGEXP_INSTR( string1, 'nevermore\\d') AS position
+   *   FROM demo1
+   *   ORDER BY id;
+   *
+   *   +----+-------------------------------------+------------+----------+
+   * | ID | STRING1                             | SUBSTRING  | POSITION |
+   * |----+-------------------------------------+------------+----------|
+   * |  1 | nevermore1, nevermore2, nevermore3. | nevermore1 |        1 |
+   * +----+-------------------------------------+------------+----------+
+   * }
+ * + * The position is not zero based, but 1 based index. Returns 0 if substr could not be found in + * str. + * + * @param str Column on which instr has to be applied + * @param substring Pattern to be retrieved + * @return A null if either of the arguments are null. + * @since 1.15.0 + */ + public static Column instr(Column str, String substring) { + return new Column(com.snowflake.snowpark.functions.instr(str.toScalaColumn(), substring)); + } + + /** + * Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in UTC, and renders + * that time as a timestamp in the given time zone. For example, 'GMT+1' would yield '2017-07-14 + * 03:40:00.0'. + * + *

For Example + * + *

{@code
+   * ALTER SESSION SET TIMEZONE = 'America/Los_Angeles';
+   * SELECT TO_TIMESTAMP_TZ('2024-04-05 01:02:03');
+   *  +----------------------------------------+
+   * | TO_TIMESTAMP_TZ('2024-04-05 01:02:03') |
+   * |----------------------------------------|
+   * | 2024-04-05 01:02:03.000 -0700          |
+   * +----------------------------------------+
+   * }
+ * + * @param ts A date, timestamp or string. If a string, the data must be in a format that can be + * cast to a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` A string detailing + * the time zone ID that the input should be adjusted to. It should be in the format of either + * region-based zone IDs or zone offsets. Region IDs must have the form 'area/city', such as + * 'America/Los_Angeles'. Zone offsets must be in the format '(+|-)HH:mm', for example + * '-08:00' or '+01:00'. Also 'UTC' and 'Z' are supported as aliases of '+00:00'. Other short + * names are not recommended to use because they can be ambiguous. + * @return A timestamp, or null if `ts` was a string that could not be cast to a timestamp or `tz` + * was an invalid value + * @since 1.15.0 + */ + public static Column from_utc_timestamp(Column ts) { + return new Column(com.snowflake.snowpark.functions.from_utc_timestamp(ts.toScalaColumn())); + } + + /** + * Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in the given time zone, + * and renders that time as a timestamp in UTC. For example, 'GMT+1' would yield '2017-07-14 + * 01:40:00.0'. + * + * @param ts A date, timestamp or string. If a string, the data must be in a format that can be + * cast to a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` A string detailing + * the time zone ID that the input should be adjusted to. It should be in the format of either + * region-based zone IDs or zone offsets. Region IDs must have the form 'area/city', such as + * 'America/Los_Angeles'. Zone offsets must be in the format '(+|-)HH:mm', for example + * '-08:00' or '+01:00'. Also 'UTC' and 'Z' are supported as aliases of '+00:00'. Other short + * names are not recommended to use because they can be ambiguous. + * @return A timestamp, or null if `ts` was a string that could not be cast to a timestamp or `tz` + * was an invalid value + * @since 1.15.0 + */ + public static Column to_utc_timestamp(Column ts) { + return new Column(com.snowflake.snowpark.functions.to_utc_timestamp(ts.toScalaColumn())); + } + + /** + * Formats numeric column x to a format like '#,###,###.##', rounded to d decimal places with + * HALF_EVEN round mode, and returns the result as a string column. + * + *

If d is 0, the result has no decimal point or fractional part. If d is less than 0, the + * result will be null. + * + * @param x numeric column to be transformed + * @param d Amount of decimal for the number format + * @return Number casted to the specific string format + * @since 1.15.0 + */ + public static Column format_number(Column x, Integer d) { + return new Column(com.snowflake.snowpark.functions.format_number(x.toScalaColumn(), d)); + } /* Returns a Column expression with values sorted in descending order. * diff --git a/src/main/java/com/snowflake/snowpark_java/Row.java b/src/main/java/com/snowflake/snowpark_java/Row.java index 0921a0d6..74475e2d 100644 --- a/src/main/java/com/snowflake/snowpark_java/Row.java +++ b/src/main/java/com/snowflake/snowpark_java/Row.java @@ -425,6 +425,64 @@ public Row getObject(int index) { return (Row) get(index); } + /** + * Returns the value at the specified column index and casts it to the desired type {@code T}. + * + *

Example: + * + *

{@code
+   * Row row = Row.create(1, "Alice", 95.5);
+   * row.getAs(0, Integer.class); // Returns 1 as an Int
+   * row.getAs(1, String.class); // Returns "Alice" as a String
+   * row.getAs(2, Double.class); // Returns 95.5 as a Double
+   * }
+ * + * @param index the zero-based column index within the row. + * @param clazz the {@code Class} object representing the type {@code T}. + * @param the expected type of the value at the specified column index. + * @return the value at the specified column index cast to type {@code T}. + * @throws ClassCastException if the value at the given index cannot be cast to type {@code T}. + * @throws ArrayIndexOutOfBoundsException if the column index is out of bounds. + * @since 1.15.0 + */ + @SuppressWarnings("unchecked") + public T getAs(int index, Class clazz) + throws ClassCastException, ArrayIndexOutOfBoundsException { + if (isNullAt(index)) { + return (T) get(index); + } + + if (clazz == Byte.class) { + return (T) (Object) getByte(index); + } + + if (clazz == Double.class) { + return (T) (Object) getDouble(index); + } + + if (clazz == Float.class) { + return (T) (Object) getFloat(index); + } + + if (clazz == Integer.class) { + return (T) (Object) getInt(index); + } + + if (clazz == Long.class) { + return (T) (Object) getLong(index); + } + + if (clazz == Short.class) { + return (T) (Object) getShort(index); + } + + if (clazz == Variant.class) { + return (T) getVariant(index); + } + + return (T) get(index); + } + /** * Generates a string value to represent the content of this row. * diff --git a/src/main/scala/com/snowflake/snowpark/Row.scala b/src/main/scala/com/snowflake/snowpark/Row.scala index ac929f11..40ec4ffa 100644 --- a/src/main/scala/com/snowflake/snowpark/Row.scala +++ b/src/main/scala/com/snowflake/snowpark/Row.scala @@ -331,6 +331,43 @@ class Row protected (values: Array[Any]) extends Serializable { getAs[Map[T, U]](index) } + /** Returns the value at the specified column index and casts it to the desired type `T`. + * + * Example: + * {{{ + * val row = Row(1, "Alice", 95.5) + * row.getAs[Int](0) // Returns 1 as an Int + * row.getAs[String](1) // Returns "Alice" as a String + * row.getAs[Double](2) // Returns 95.5 as a Double + * }}} + * + * @param index + * the zero-based column index within the row. + * @tparam T + * the expected type of the value at the specified column index. + * @return + * the value at the specified column index cast to type `T`. + * @throws ClassCastException + * if the value at the given index cannot be cast to type `T`. + * @throws ArrayIndexOutOfBoundsException + * if the column index is out of bounds. + * @group getter + * @since 1.15.0 + */ + def getAs[T](index: Int)(implicit classTag: ClassTag[T]): T = { + classTag.runtimeClass match { + case _ if isNullAt(index) => get(index).asInstanceOf[T] + case c if c == classOf[Byte] => getByte(index).asInstanceOf[T] + case c if c == classOf[Double] => getDouble(index).asInstanceOf[T] + case c if c == classOf[Float] => getFloat(index).asInstanceOf[T] + case c if c == classOf[Int] => getInt(index).asInstanceOf[T] + case c if c == classOf[Long] => getLong(index).asInstanceOf[T] + case c if c == classOf[Short] => getShort(index).asInstanceOf[T] + case c if c == classOf[Variant] => getVariant(index).asInstanceOf[T] + case _ => get(index).asInstanceOf[T] + } + } + protected def convertValueToString(value: Any): String = value match { case null => "null" @@ -362,10 +399,7 @@ class Row protected (values: Array[Any]) extends Serializable { .map(convertValueToString) .mkString("Row[", ",", "]") - private def getAs[T](index: Int): T = get(index).asInstanceOf[T] - - private def getAnyValAs[T <: AnyVal](index: Int): T = + private def getAnyValAs[T <: AnyVal](index: Int)(implicit classTag: ClassTag[T]): T = if (isNullAt(index)) throw new NullPointerException(s"Value at index $index is null") - else getAs[T](index) - + else getAs[T](index)(classTag) } diff --git a/src/main/scala/com/snowflake/snowpark/functions.scala b/src/main/scala/com/snowflake/snowpark/functions.scala index e3dad8ec..69970c5f 100644 --- a/src/main/scala/com/snowflake/snowpark/functions.scala +++ b/src/main/scala/com/snowflake/snowpark/functions.scala @@ -800,6 +800,206 @@ object functions { */ def pow(l: Column, r: Column): Column = builtin("pow")(l, r) + /** Returns a number (l) raised to the specified power (r). + * + * Example: + * {{{ + * val df = session.sql( + * "select * from (values (0.1, 2), (2, 3), (2, 0.5), (2, -1)) as T(base, exponent)") + * df.select(col("base"), col("exponent"), pow(col("base"), "exponent").as("result")).show() + * + * ---------------------------------------------- + * |"BASE" |"EXPONENT" |"RESULT" | + * ---------------------------------------------- + * |0.1 |2.0 |0.010000000000000002 | + * |2.0 |3.0 |8.0 | + * |2.0 |0.5 |1.4142135623730951 | + * |2.0 |-1.0 |0.5 | + * ---------------------------------------------- + * }}} + * + * @param l + * The numeric column representing the base. + * @param r + * The name of the numeric column representing the exponent. + * @return + * A column containing the result of raising `l` to the power of `r`. + * @group num_func + * @since 1.15.0 + */ + def pow(l: Column, r: String): Column = pow(l, col(r)) + + /** Returns a number (l) raised to the specified power (r). + * + * Example: + * {{{ + * val df = session.sql( + * "select * from (values (0.1, 2), (2, 3), (2, 0.5), (2, -1)) as T(base, exponent)") + * df.select(col("base"), col("exponent"), pow("base", col("exponent")).as("result")).show() + * + * ---------------------------------------------- + * |"BASE" |"EXPONENT" |"RESULT" | + * ---------------------------------------------- + * |0.1 |2.0 |0.010000000000000002 | + * |2.0 |3.0 |8.0 | + * |2.0 |0.5 |1.4142135623730951 | + * |2.0 |-1.0 |0.5 | + * ---------------------------------------------- + * }}} + * + * @param l + * The name of the numeric column representing the base. + * @param r + * The numeric column representing the exponent. + * @return + * A column containing the result of raising `l` to the power of `r`. + * @group num_func + * @since 1.15.0 + */ + def pow(l: String, r: Column): Column = pow(col(l), r) + + /** Returns a number (l) raised to the specified power (r). + * + * Example: + * {{{ + * val df = session.sql( + * "select * from (values (0.1, 2), (2, 3), (2, 0.5), (2, -1)) as T(base, exponent)") + * df.select(col("base"), col("exponent"), pow("base", "exponent").as("result")).show() + * + * ---------------------------------------------- + * |"BASE" |"EXPONENT" |"RESULT" | + * ---------------------------------------------- + * |0.1 |2.0 |0.010000000000000002 | + * |2.0 |3.0 |8.0 | + * |2.0 |0.5 |1.4142135623730951 | + * |2.0 |-1.0 |0.5 | + * ---------------------------------------------- + * }}} + * + * @param l + * The name of the numeric column representing the base. + * @param r + * The name of the numeric column representing the exponent. + * @return + * A column containing the result of raising `l` to the power of `r`. + * @group num_func + * @since 1.15.0 + */ + def pow(l: String, r: String): Column = pow(col(l), col(r)) + + /** Returns a number (l) raised to the specified power (r). + * + * Example: + * {{{ + * val df = session.sql("select * from (values (0.5), (2), (2.5), (4)) as T(base)") + * df.select(col("base"), lit(2.0).as("exponent"), pow(col("base"), 2.0).as("result")).show() + * + * ---------------------------------- + * |"BASE" |"EXPONENT" |"RESULT" | + * ---------------------------------- + * |0.5 |2.0 |0.25 | + * |2.0 |2.0 |4.0 | + * |2.5 |2.0 |6.25 | + * |4.0 |2.0 |16.0 | + * ---------------------------------- + * }}} + * + * @param l + * The numeric column representing the base. + * @param r + * The value of the exponent. + * @return + * A column containing the result of raising `l` to the power of `r`. + * @group num_func + * @since 1.15.0 + */ + def pow(l: Column, r: Double): Column = pow(l, lit(r)) + + /** Returns a number (l) raised to the specified power (r). + * + * Example: + * {{{ + * val df = session.sql("select * from (values (0.5), (2), (2.5), (4)) as T(base)") + * df.select(col("base"), lit(2.0).as("exponent"), pow("base", 2.0).as("result")).show() + * + * ---------------------------------- + * |"BASE" |"EXPONENT" |"RESULT" | + * ---------------------------------- + * |0.5 |2.0 |0.25 | + * |2.0 |2.0 |4.0 | + * |2.5 |2.0 |6.25 | + * |4.0 |2.0 |16.0 | + * ---------------------------------- + * }}} + * + * @param l + * The name of the numeric column representing the base. + * @param r + * The value of the exponent. + * @return + * A column containing the result of raising `l` to the power of `r`. + * @group num_func + * @since 1.15.0 + */ + def pow(l: String, r: Double): Column = pow(col(l), r) + + /** Returns a number (l) raised to the specified power (r). + * + * Example: + * {{{ + * val df = session.sql("select * from (values (0.5), (2), (2.5), (4)) as T(exponent)") + * df.select(lit(2.0).as("base"), col("exponent"), pow(2.0, col("exponent")).as("result")) + * .show() + * + * -------------------------------------------- + * |"BASE" |"EXPONENT" |"RESULT" | + * -------------------------------------------- + * |2.0 |0.5 |1.4142135623730951 | + * |2.0 |2.0 |4.0 | + * |2.0 |2.5 |5.656854249492381 | + * |2.0 |4.0 |16.0 | + * -------------------------------------------- + * }}} + * + * @param l + * The value of the base. + * @param r + * The numeric column representing the exponent. + * @return + * A column containing the result of raising `l` to the power of `r`. + * @group num_func + * @since 1.15.0 + */ + def pow(l: Double, r: Column): Column = pow(lit(l), r) + + /** Returns a number (l) raised to the specified power (r). + * + * Example: + * {{{ + * val df = session.sql("select * from (values (0.5), (2), (2.5), (4)) as T(exponent)") + * df.select(lit(2.0).as("base"), col("exponent"), pow(2.0, "exponent").as("result")).show() + * + * -------------------------------------------- + * |"BASE" |"EXPONENT" |"RESULT" | + * -------------------------------------------- + * |2.0 |0.5 |1.4142135623730951 | + * |2.0 |2.0 |4.0 | + * |2.0 |2.5 |5.656854249492381 | + * |2.0 |4.0 |16.0 | + * -------------------------------------------- + * }}} + * + * @param l + * The value of the base. + * @param r + * The name of the numeric column representing the exponent. + * @return + * A column containing the result of raising `l` to the power of `r`. + * @group num_func + * @since 1.15.0 + */ + def pow(l: Double, r: String): Column = pow(l, col(r)) + /** Rounds the numeric values of the given column `e` to the `scale` decimal places using the half * away from zero rounding mode. * @@ -3255,6 +3455,111 @@ object functions { */ def monotonically_increasing_id(): Column = builtin("seq8")() + /** Returns number of months between dates `start` and `end`. + * + * A whole number is returned if both inputs have the same day of month or both are the last day + * of their respective months. Otherwise, the difference is calculated assuming 31 days per + * month. + * + * For example: + * {{{ + * months_between("2017-11-14", "2017-07-14") // returns 4.0 + * months_between("2017-01-01", "2017-01-10") // returns 0.29032258 + * months_between("2017-06-01", "2017-06-16 12:00:00") // returns -0.5 + * }}} + * @since 1.15.0 + * @param end + * Column name. If a string, the data must be in a format that can be cast to a timestamp, such + * as yyyy-MM-dd or yyyy-MM-dd HH:mm:ss.SSSS + * @param start + * Column name . If a string, the data must be in a format that can cast to a timestamp, such + * as yyyy-MM-dd or yyyy-MM-dd HH:mm:ss.SSSS + * @return + * A double, or null if either end or start were strings that could not be cast to a timestamp. + * Negative if end is before start + */ + def months_between(end: String, start: String): Column = + builtin("MONTHS_BETWEEN")(col(end), col(start)) + + /** Locate the position of the first occurrence of substr column in the given string. Returns null + * if either of the arguments are null. For example SELECT id, string1, REGEXP_SUBSTR(string1, + * 'nevermore\\d') AS substring, REGEXP_INSTR( string1, 'nevermore\\d') AS position FROM demo1 + * ORDER BY id; + * | ID | STRING1 | SUBSTRING | POSITION | + * |:-------------------------------------------------------------------|:------------------------------------|:-----------|:---------| + * | ----+-------------------------------------+------------+---------- | | | | + * | 1 | nevermore1, nevermore2, nevermore3. | nevermore1 | 1 | + * + * @since 1.15.0 + * @note + * The position is not zero based, but 1 based index. Returns 0 if substr could not be found in + * str. + */ + def instr(str: Column, substring: String): Column = builtin("REGEXP_INSTR")(str, substring) + + /** Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in UTC, and renders + * that time as a timestamp in the given time zone. For example, 'GMT+1' would yield '2017-07-14 + * 03:40:00.0'. ALTER SESSION SET TIMEZONE = 'America/Los_Angeles'; SELECT + * TO_TIMESTAMP_TZ('2024-04-05 01:02:03'); + * | TO_TIMESTAMP_TZ('2024-04-05 01:02:03') | + * |:---------------------------------------| + * | 2024-04-05 01:02:03.000 -0700 | + * + * @since 1.15.0 + * @param ts + * A date, timestamp or string. If a string, the data must be in a format that can be cast to a + * timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` A string detailing the time + * zone ID that the input should be adjusted to. It should be in the format of either + * region-based zone IDs or zone offsets. Region IDs must have the form 'area/city', such as + * 'America/Los_Angeles'. Zone offsets must be in the format '(+|-)HH:mm', for example '-08:00' + * or '+01:00'. Also 'UTC' and 'Z' are supported as aliases of '+00:00'. Other short names are + * not recommended to use because they can be ambiguous. + * @return + * A timestamp, or null if `ts` was a string that could not be cast to a timestamp or `tz` was + * an invalid value + */ + def from_utc_timestamp(ts: Column): Column = + builtin("TO_TIMESTAMP_TZ")(ts) + + /** Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in the given time + * zone, and renders that time as a timestamp in UTC. For example, 'GMT+1' would yield + * '2017-07-14 01:40:00.0'. + * @since 1.15.0 + * @param ts + * A date, timestamp or string. If a string, the data must be in a format that can be cast to a + * timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS` A string detailing the time + * zone ID that the input should be adjusted to. It should be in the format of either + * region-based zone IDs or zone offsets. Region IDs must have the form 'area/city', such as + * 'America/Los_Angeles'. Zone offsets must be in the format '(+|-)HH:mm', for example '-08:00' + * or '+01:00'. Also 'UTC' and 'Z' are supported as aliases of '+00:00'. Other short names are + * not recommended to use because they can be ambiguous. + * @return + * A timestamp, or null if `ts` was a string that could not be cast to a timestamp or `tz` was + * an invalid value + */ + def to_utc_timestamp(ts: Column): Column = builtin("TO_TIMESTAMP_TZ")(ts) + + /** Formats numeric column x to a format like '#,###,###.##', rounded to d decimal places with + * HALF_EVEN round mode, and returns the result as a string column. + * @since 1.15.0 + * If d is 0, the result has no decimal point or fractional part. If d is less than 0, the + * result will be null. + * + * @param x + * numeric column to be transformed + * @param d + * Amount of decimal for the number format + * + * @return + * Number casted to the specific string format + */ + def format_number(x: Column, d: Int): Column = { + if (d < 0) { + lit(null) + } else { + builtin("TO_VARCHAR")(x, if (d > 0) s"999,999.${"0" * d}" else "999,999") + } + } /* Returns a Column expression with values sorted in descending order. * Example: * {{{ diff --git a/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java index 0d2d2b30..89e9cbb5 100644 --- a/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java +++ b/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java @@ -593,11 +593,46 @@ public void pow() { DataFrame df = getSession().sql("select * from values(0.1, 0.5),(0.2, 0.6),(0.3, 0.7) as T(a,b)"); Row[] expected = { - Row.create(0.31622776601683794), - Row.create(0.3807307877431757), - Row.create(0.4305116202499342) + Row.create( + 0.31622776601683794, + 0.31622776601683794, + 0.31622776601683794, + 0.31622776601683794, + 0.15848931924611134, + 0.15848931924611134, + 0.6324555320336759, + 0.6324555320336759), + Row.create( + 0.3807307877431757, + 0.3807307877431757, + 0.3807307877431757, + 0.3807307877431757, + 0.27594593229224296, + 0.27594593229224296, + 0.5770799623628855, + 0.5770799623628855), + Row.create( + 0.4305116202499342, + 0.4305116202499342, + 0.4305116202499342, + 0.4305116202499342, + 0.3816778909618176, + 0.3816778909618176, + 0.526552881733695, + 0.526552881733695) }; - checkAnswer(df.select(Functions.pow(df.col("a"), df.col("b"))), expected, false); + checkAnswer( + df.select( + Functions.pow(df.col("a"), df.col("b")), + Functions.pow(df.col("a"), "b"), + Functions.pow("a", df.col("b")), + Functions.pow("a", "b"), + Functions.pow(df.col("a"), 0.8), + Functions.pow("a", 0.8), + Functions.pow(0.4, df.col("b")), + Functions.pow(0.4, "b")), + expected, + false); } @Test @@ -3100,4 +3135,55 @@ public void unhex() { Row[] expected = {Row.create("1"), Row.create("2"), Row.create("3")}; checkAnswer(df.select(Functions.unhex(Functions.col("a"))), expected, false); } + + @Test + public void months_between() { + DataFrame df = + getSession() + .sql( + "select * from values('2010-07-02'::Date,'2010-08-02'::Date), " + + "('2020-08-02'::Date,'2020-12-02'::Date) as t(a,b)"); + Row[] expected = {Row.create(1.000000), Row.create(4.000000)}; + checkAnswer(df.select(Functions.months_between("b", "a")), expected, false); + } + + @Test + public void instr() { + DataFrame df = + getSession() + .sql( + "select * from values('It was the best of times, it was the worst of times') as t(a)"); + Row[] expected = {Row.create(4)}; + checkAnswer(df.select(Functions.instr(df.col("a"), "was")), expected, false); + } + + @Test + public void format_number1() { + DataFrame df = getSession().sql("select * from values(1),(2),(3) as t(a)"); + Row[] expected = {Row.create("1"), Row.create("2"), Row.create("3")}; + checkAnswer( + df.select(Functions.ltrim(Functions.format_number(df.col("a"), 0))), expected, false); + } + + @Test + public void format_number2() { + DataFrame df = getSession().sql("select * from values(1),(2),(3) as t(a)"); + Row[] expected = {Row.create("1.00"), Row.create("2.00"), Row.create("3.00")}; + checkAnswer( + df.select(Functions.ltrim(Functions.format_number(df.col("a"), 2))), expected, false); + } + + @Test + public void from_utc_timestamp() { + DataFrame df = getSession().sql("select * from values('2024-04-05 01:02:03') as t(a)"); + Row[] expected = {Row.create(Timestamp.valueOf("2024-04-05 01:02:03.0"))}; + checkAnswer(df.select(Functions.from_utc_timestamp(df.col("a"))), expected, false); + } + + @Test + public void to_utc_timestamp() { + DataFrame df = getSession().sql("select * from values('2024-04-05 01:02:03') as t(a)"); + Row[] expected = {Row.create(Timestamp.valueOf("2024-04-05 01:02:03.0"))}; + checkAnswer(df.select(Functions.to_utc_timestamp(df.col("a"))), expected, false); + } } diff --git a/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java index 0570c421..f8918292 100644 --- a/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java +++ b/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java @@ -1,5 +1,7 @@ package com.snowflake.snowpark_test; +import static org.junit.Assert.assertThrows; + import com.snowflake.snowpark_java.DataFrame; import com.snowflake.snowpark_java.Row; import com.snowflake.snowpark_java.types.*; @@ -7,10 +9,13 @@ import java.sql.Date; import java.sql.Time; import java.sql.Timestamp; +import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.TimeZone; import org.junit.Test; public class JavaRowSuite extends TestBase { @@ -429,4 +434,172 @@ public void testGetRow() { }, getSession()); } + + @Test + public void getAs() { + long milliseconds = System.currentTimeMillis(); + + StructType schema = + StructType.create( + new StructField("c01", DataTypes.BinaryType), + new StructField("c02", DataTypes.BooleanType), + new StructField("c03", DataTypes.ByteType), + new StructField("c04", DataTypes.DateType), + new StructField("c05", DataTypes.DoubleType), + new StructField("c06", DataTypes.FloatType), + new StructField("c07", DataTypes.GeographyType), + new StructField("c08", DataTypes.GeometryType), + new StructField("c09", DataTypes.IntegerType), + new StructField("c10", DataTypes.LongType), + new StructField("c11", DataTypes.ShortType), + new StructField("c12", DataTypes.StringType), + new StructField("c13", DataTypes.TimeType), + new StructField("c14", DataTypes.TimestampType), + new StructField("c15", DataTypes.VariantType)); + + Row[] data = { + Row.create( + new byte[] {1, 2}, + true, + Byte.MIN_VALUE, + Date.valueOf("2024-01-01"), + Double.MIN_VALUE, + Float.MIN_VALUE, + Geography.fromGeoJSON("POINT(30 10)"), + Geometry.fromGeoJSON("POINT(20 40)"), + Integer.MIN_VALUE, + Long.MIN_VALUE, + Short.MIN_VALUE, + "string", + Time.valueOf("16:23:04"), + new Timestamp(milliseconds), + new Variant(1)) + }; + + DataFrame df = getSession().createDataFrame(data, schema); + Row row = df.collect()[0]; + + assert Arrays.equals(row.getAs(0, byte[].class), new byte[] {1, 2}); + assert row.getAs(1, Boolean.class); + assert row.getAs(2, Byte.class) == Byte.MIN_VALUE; + assert row.getAs(3, Date.class).equals(Date.valueOf("2024-01-01")); + assert row.getAs(4, Double.class) == Double.MIN_VALUE; + assert row.getAs(5, Float.class) == Float.MIN_VALUE; + assert row.getAs(6, Geography.class) + .equals( + Geography.fromGeoJSON( + "{\n \"coordinates\": [\n 30,\n 10\n ],\n \"type\": \"Point\"\n}")); + assert row.getAs(7, Geometry.class) + .equals( + Geometry.fromGeoJSON( + "{\n \"coordinates\": [\n 2.000000000000000e+01,\n 4.000000000000000e+01\n ],\n \"type\": \"Point\"\n}")); + assert row.getAs(8, Integer.class) == Integer.MIN_VALUE; + assert row.getAs(9, Long.class) == Long.MIN_VALUE; + assert row.getAs(10, Short.class) == Short.MIN_VALUE; + assert row.getAs(11, String.class).equals("string"); + assert row.getAs(12, Time.class).equals(Time.valueOf("16:23:04")); + assert row.getAs(13, Timestamp.class).equals(new Timestamp(milliseconds)); + assert row.getAs(14, Variant.class).equals(new Variant(1)); + + Row finalRow = row; + assertThrows( + ClassCastException.class, + () -> { + Boolean b = finalRow.getAs(0, Boolean.class); + }); + assertThrows(ArrayIndexOutOfBoundsException.class, () -> finalRow.getAs(-1, Boolean.class)); + + data = + new Row[] { + Row.create( + null, null, null, null, null, null, null, null, null, null, null, null, null, null, + null) + }; + + df = getSession().createDataFrame(data, schema); + row = df.collect()[0]; + + assert row.getAs(0, byte[].class) == null; + assert row.getAs(1, Boolean.class) == null; + assert row.getAs(2, Byte.class) == null; + assert row.getAs(3, Date.class) == null; + assert row.getAs(4, Double.class) == null; + assert row.getAs(5, Float.class) == null; + assert row.getAs(6, Geography.class) == null; + assert row.getAs(7, Geometry.class) == null; + assert row.getAs(8, Integer.class) == null; + assert row.getAs(9, Long.class) == null; + assert row.getAs(10, Short.class) == null; + assert row.getAs(11, String.class) == null; + assert row.getAs(12, Time.class) == null; + assert row.getAs(13, Timestamp.class) == null; + assert row.getAs(14, Variant.class) == null; + } + + @Test + public void getAsWithStructuredMap() { + structuredTypeTest( + () -> { + String query = + "SELECT " + + "{'a':1,'b':2}::MAP(VARCHAR, NUMBER) as map1," + + "{'1':'a','2':'b'}::MAP(NUMBER, VARCHAR) as map2," + + "{'1':{'a':1,'b':2},'2':{'c':3}}::MAP(NUMBER, MAP(VARCHAR, NUMBER)) as map3"; + + DataFrame df = getSession().sql(query); + Row row = df.collect()[0]; + + Map map1 = row.getAs(0, Map.class); + assert (Long) map1.get("a") == 1L; + assert (Long) map1.get("b") == 2L; + + Map map2 = row.getAs(1, Map.class); + assert map2.get(1L).equals("a"); + assert map2.get(2L).equals("b"); + + Map map3 = row.getAs(2, Map.class); + Map map3ExpectedInnerMap = new HashMap<>(); + map3ExpectedInnerMap.put("a", 1L); + map3ExpectedInnerMap.put("b", 2L); + assert map3.get(1L).equals(map3ExpectedInnerMap); + assert map3.get(2L).equals(Collections.singletonMap("c", 3L)); + }, + getSession()); + } + + @Test + public void getAsWithStructuredArray() { + structuredTypeTest( + () -> { + TimeZone oldTimeZone = TimeZone.getDefault(); + try { + TimeZone.setDefault(TimeZone.getTimeZone("US/Pacific")); + + String query = + "SELECT " + + "[1,2,3]::ARRAY(NUMBER) AS arr1," + + "['a','b']::ARRAY(VARCHAR) AS arr2," + + "[parse_json(31000000)::timestamp_ntz]::ARRAY(TIMESTAMP_NTZ) AS arr3," + + "[[1,2]]::ARRAY(ARRAY) AS arr4"; + + DataFrame df = getSession().sql(query); + Row row = df.collect()[0]; + + ArrayList array1 = row.getAs(0, ArrayList.class); + assert array1.equals(Arrays.asList(1L, 2L, 3L)); + + ArrayList array2 = row.getAs(1, ArrayList.class); + assert array2.equals(Arrays.asList("a", "b")); + + ArrayList array3 = row.getAs(2, ArrayList.class); + assert array3.equals(Collections.singletonList(new Timestamp(31000000000L))); + + ArrayList array4 = row.getAs(3, ArrayList.class); + assert array4.equals(Collections.singletonList("[\n 1,\n 2\n]")); + } finally { + TimeZone.setDefault(oldTimeZone); + } + }, + getSession()); + } } diff --git a/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala b/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala index 546164fc..3ac7996b 100644 --- a/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala @@ -353,10 +353,25 @@ trait FunctionSuite extends TestData { test("pow") { checkAnswer( - double2.select(pow(col("A"), col("B"))), - Seq(Row(0.31622776601683794), Row(0.3807307877431757), Row(0.4305116202499342)), + double2.select( + pow(col("A"), col("B")), + pow(col("A"), "B"), + pow("A", col("B")), + pow("A", "B"), + pow(col("A"), 0.8), + pow("A", 0.8), + pow(0.4, col("B")), + pow(0.4, "B")), + Seq( + Row(0.31622776601683794, 0.31622776601683794, 0.31622776601683794, 0.31622776601683794, + 0.15848931924611134, 0.15848931924611134, 0.6324555320336759, 0.6324555320336759), + Row(0.3807307877431757, 0.3807307877431757, 0.3807307877431757, 0.3807307877431757, + 0.27594593229224296, 0.27594593229224296, 0.5770799623628855, 0.5770799623628855), + Row(0.4305116202499342, 0.4305116202499342, 0.4305116202499342, 0.4305116202499342, + 0.3816778909618176, 0.3816778909618176, 0.526552881733695, 0.526552881733695)), sort = false) } + test("shiftleft shiftright") { checkAnswer( integer1.select(bitshiftleft(col("A"), lit(1)), bitshiftright(col("A"), lit(1))), @@ -2455,6 +2470,58 @@ trait FunctionSuite extends TestData { Seq(Row("1"), Row("2"), Row("3")), sort = false) } + test("months_between") { + val months_between = functions.builtin("MONTHS_BETWEEN") + val input = Seq( + (Date.valueOf("2010-08-02"), Date.valueOf("2010-07-02")), + (Date.valueOf("2020-12-02"), Date.valueOf("2020-08-02"))) + .toDF("a", "b") + checkAnswer( + input.select(months_between(col("a"), col("b"))), + Seq(Row((1.000000)), Row(4.000000)), + sort = false) + } + + test("instr") { + val df = Seq("It was the best of times, it was the worst of times").toDF("a") + checkAnswer(df.select(instr(col("a"), "was")), Seq(Row(4)), sort = false) + } + + test("format_number1") { + + checkAnswer( + number3.select(ltrim(format_number(col("a"), 0))), + Seq(Row(("1")), Row(("2")), Row(("3"))), + sort = false) + } + + test("format_number2") { + + checkAnswer( + number3.select(ltrim(format_number(col("a"), 2))), + Seq(Row(("1.00")), Row(("2.00")), Row(("3.00"))), + sort = false) + } + + test("format_number3") { + + checkAnswer( + number3.select(ltrim(format_number(col("a"), -1))), + Seq(Row((null)), Row((null)), Row((null))), + sort = false) + } + + test("from_utc_timestamp") { + val expected = Seq(Timestamp.valueOf("2024-04-05 01:02:03.0")).toDF("a") + val data = Seq("2024-04-05 01:02:03").toDF("a") + checkAnswer(data.select(from_utc_timestamp(col("a"))), expected, sort = false) + } + + test("to_utc_timestamp") { + val expected = Seq(Timestamp.valueOf("2024-04-05 01:02:03.0")).toDF("a") + val data = Seq("2024-04-05 01:02:03").toDF("a") + checkAnswer(data.select(to_utc_timestamp(col("a"))), expected, sort = false) + } } diff --git a/src/test/scala/com/snowflake/snowpark_test/RowSuite.scala b/src/test/scala/com/snowflake/snowpark_test/RowSuite.scala index 54aba687..8de21e92 100644 --- a/src/test/scala/com/snowflake/snowpark_test/RowSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/RowSuite.scala @@ -1,11 +1,12 @@ package com.snowflake.snowpark_test -import com.snowflake.snowpark.types.{Geography, Geometry, Variant} +import com.snowflake.snowpark.types._ import com.snowflake.snowpark.{Row, SNTestBase, SnowparkClientException} -import java.sql.{Date, Timestamp} +import java.sql.{Date, Time, Timestamp} import java.time.{Instant, LocalDate} import java.util +import java.util.TimeZone class RowSuite extends SNTestBase { @@ -233,6 +234,161 @@ class RowSuite extends SNTestBase { assert(err.message.matches(msg)) } + test("getAs") { + val milliseconds = System.currentTimeMillis() + + val schema = StructType( + Seq( + StructField("c01", BinaryType), + StructField("c02", BooleanType), + StructField("c03", ByteType), + StructField("c04", DateType), + StructField("c05", DoubleType), + StructField("c06", FloatType), + StructField("c07", GeographyType), + StructField("c08", GeometryType), + StructField("c09", IntegerType), + StructField("c10", LongType), + StructField("c11", ShortType), + StructField("c12", StringType), + StructField("c13", TimeType), + StructField("c14", TimestampType), + StructField("c15", VariantType))) + + var data = Seq( + Row( + Array[Byte](1, 9), + true, + Byte.MinValue, + Date.valueOf("2024-01-01"), + Double.MinValue, + Float.MinPositiveValue, + Geography.fromGeoJSON("POINT(30 10)"), + Geometry.fromGeoJSON("POINT(20 40)"), + Int.MinValue, + Long.MinValue, + Short.MinValue, + "string", + Time.valueOf("16:23:04"), + new Timestamp(milliseconds), + new Variant(1))) + + var df = session.createDataFrame(data, schema) + var row = df.collect()(0) + + assert(row.getAs[Array[Byte]](0) sameElements Array[Byte](1, 9)) + assert(row.getAs[Boolean](1)) + assert(row.getAs[Byte](2) == Byte.MinValue) + assert(row.getAs[Date](3) == Date.valueOf("2024-01-01")) + assert(row.getAs[Double](4) == Double.MinValue) + assert(row.getAs[Float](5) == Float.MinPositiveValue) + assert(row.getAs[Geography](6) == Geography.fromGeoJSON("""{ + | "coordinates": [ + | 30, + | 10 + | ], + | "type": "Point" + |}""".stripMargin)) + assert(row.getAs[Geometry](7) == Geometry.fromGeoJSON("""{ + | "coordinates": [ + | 2.000000000000000e+01, + | 4.000000000000000e+01 + | ], + | "type": "Point" + |}""".stripMargin)) + assert(row.getAs[Int](8) == Int.MinValue) + assert(row.getAs[Long](9) == Long.MinValue) + assert(row.getAs[Short](10) == Short.MinValue) + assert(row.getAs[String](11) == "string") + assert(row.getAs[Time](12) == Time.valueOf("16:23:04")) + assert(row.getAs[Timestamp](13) == new Timestamp(milliseconds)) + assert(row.getAs[Variant](14) == new Variant(1)) + assertThrows[ClassCastException](row.getAs[Boolean](0)) + assertThrows[ArrayIndexOutOfBoundsException](row.getAs[Boolean](-1)) + + data = Seq( + Row(null, null, null, null, null, null, null, null, null, null, null, null, null, null, null)) + + df = session.createDataFrame(data, schema) + row = df.collect()(0) + + assert(row.getAs[Array[Byte]](0) == null) + assert(!row.getAs[Boolean](1)) + assert(row.getAs[Byte](2) == 0) + assert(row.getAs[Date](3) == null) + assert(row.getAs[Double](4) == 0) + assert(row.getAs[Float](5) == 0) + assert(row.getAs[Geography](6) == null) + assert(row.getAs[Geometry](7) == null) + assert(row.getAs[Int](8) == 0) + assert(row.getAs[Long](9) == 0) + assert(row.getAs[Short](10) == 0) + assert(row.getAs[String](11) == null) + assert(row.getAs[Time](12) == null) + assert(row.getAs[Timestamp](13) == null) + assert(row.getAs[Variant](14) == null) + } + + test("getAs with structured map") { + structuredTypeTest { + val query = + """SELECT + | {'a':1,'b':2}::MAP(VARCHAR, NUMBER) as map1, + | {'1':'a','2':'b'}::MAP(NUMBER, VARCHAR) as map2, + | {'1':{'a':1,'b':2},'2':{'c':3}}::MAP(NUMBER, MAP(VARCHAR, NUMBER)) as map3 + |""".stripMargin + + val df = session.sql(query) + val row = df.collect()(0) + + val map1 = row.getAs[Map[String, Long]](0) + assert(map1("a") == 1L) + assert(map1("b") == 2L) + + val map2 = row.getAs[Map[Long, String]](1) + assert(map2(1) == "a") + assert(map2(2) == "b") + + val map3 = row.getAs[Map[Long, Map[String, Long]]](2) + assert(map3(1) == Map("a" -> 1, "b" -> 2)) + assert(map3(2) == Map("c" -> 3)) + } + } + + test("getAs with structured array") { + structuredTypeTest { + val oldTimeZone = TimeZone.getDefault + try { + TimeZone.setDefault(TimeZone.getTimeZone("US/Pacific")) + + val query = + """SELECT + | [1,2,3]::ARRAY(NUMBER) AS arr1, + | ['a','b']::ARRAY(VARCHAR) AS arr2, + | [parse_json(31000000)::timestamp_ntz]::ARRAY(TIMESTAMP_NTZ) AS arr3, + | [[1,2]]::ARRAY(ARRAY) AS arr4 + |""".stripMargin + + val df = session.sql(query) + val row = df.collect()(0) + + val array1 = row.getAs[Array[Object]](0) + assert(array1 sameElements Array(1, 2, 3)) + + val array2 = row.getAs[Array[Object]](1) + assert(array2 sameElements Array("a", "b")) + + val array3 = row.getAs[Array[Object]](2) + assert(array3 sameElements Array(new Timestamp(31000000000L))) + + val array4 = row.getAs[Array[Object]](3) + assert(array4 sameElements Array("[\n 1,\n 2\n]")) + } finally { + TimeZone.setDefault(oldTimeZone) + } + } + } + test("hashCode") { val row1 = Row(1, 2, 3) val row2 = Row("str", null, 3)