Skip to content

Commit

Permalink
Merge branch 'SNOW-802269-months_between_format_number' of https://gi…
Browse files Browse the repository at this point in the history
…thub.com/snowflakedb/snowpark-java-scala into SNOW-802269-months_between_format_number
  • Loading branch information
sfc-gh-sjayabalan committed Sep 10, 2024
2 parents fc16969 + 7bdf07d commit 21344d7
Show file tree
Hide file tree
Showing 8 changed files with 877 additions and 13 deletions.
196 changes: 196 additions & 0 deletions src/main/java/com/snowflake/snowpark_java/Functions.java
Original file line number Diff line number Diff line change
Expand Up @@ -1028,6 +1028,202 @@ public static Column pow(Column l, Column r) {
return new Column(com.snowflake.snowpark.functions.pow(l.toScalaColumn(), r.toScalaColumn()));
}

/**
* Returns a number (l) raised to the specified power (r).
*
* <p>Example:
*
* <pre>{@code
* DataFrame df = session.sql("select * from (values (0.1, 2), (2, 3), (2, 0.5), (2, -1)) as T(base, exponent)");
* df.select(col("base"), col("exponent"), pow(col("base"), "exponent").as("result")).show();
*
* ----------------------------------------------
* |"BASE" |"EXPONENT" |"RESULT" |
* ----------------------------------------------
* |0.1 |2.0 |0.010000000000000002 |
* |2.0 |3.0 |8.0 |
* |2.0 |0.5 |1.4142135623730951 |
* |2.0 |-1.0 |0.5 |
* ----------------------------------------------
* }</pre>
*
* @param l The numeric column representing the base.
* @param r The name of the numeric column representing the exponent.
* @return A column containing the result of raising {@code l} to the power of {@code r}.
* @since 1.15.0
*/
public static Column pow(Column l, String r) {
return new Column(com.snowflake.snowpark.functions.pow(l.toScalaColumn(), r));
}

/**
* Returns a number (l) raised to the specified power (r).
*
* <p>Example:
*
* <pre>{@code
* DataFrame df = session.sql("select * from (values (0.1, 2), (2, 3), (2, 0.5), (2, -1)) as T(base, exponent)");
* df.select(col("base"), col("exponent"), pow("base", col("exponent")).as("result")).show();
*
* ----------------------------------------------
* |"BASE" |"EXPONENT" |"RESULT" |
* ----------------------------------------------
* |0.1 |2.0 |0.010000000000000002 |
* |2.0 |3.0 |8.0 |
* |2.0 |0.5 |1.4142135623730951 |
* |2.0 |-1.0 |0.5 |
* ----------------------------------------------
* }</pre>
*
* @param l The name of the numeric column representing the base.
* @param r The numeric column representing the exponent.
* @return A column containing the result of raising {@code l} to the power of {@code r}.
* @since 1.15.0
*/
public static Column pow(String l, Column r) {
return new Column(com.snowflake.snowpark.functions.pow(l, r.toScalaColumn()));
}

/**
* Returns a number (l) raised to the specified power (r).
*
* <p>Example:
*
* <pre>{@code
* DataFrame df = session.sql("select * from (values (0.1, 2), (2, 3), (2, 0.5), (2, -1)) as T(base, exponent)");
* df.select(col("base"), col("exponent"), pow("base", "exponent").as("result")).show();
*
* ----------------------------------------------
* |"BASE" |"EXPONENT" |"RESULT" |
* ----------------------------------------------
* |0.1 |2.0 |0.010000000000000002 |
* |2.0 |3.0 |8.0 |
* |2.0 |0.5 |1.4142135623730951 |
* |2.0 |-1.0 |0.5 |
* ----------------------------------------------
* }</pre>
*
* @param l The name of the numeric column representing the base.
* @param r The name of the numeric column representing the exponent.
* @return A column containing the result of raising {@code l} to the power of {@code r}.
* @since 1.15.0
*/
public static Column pow(String l, String r) {
return new Column(com.snowflake.snowpark.functions.pow(l, r));
}

/**
* Returns a number (l) raised to the specified power (r).
*
* <p>Example:
*
* <pre>{@code
* DataFrame df = session.sql("select * from (values (0.5), (2), (2.5), (4)) as T(base)");
* df.select(col("base"), lit(2.0).as("exponent"), pow(col("base"), 2.0).as("result")).show();
*
* ----------------------------------
* |"BASE" |"EXPONENT" |"RESULT" |
* ----------------------------------
* |0.5 |2.0 |0.25 |
* |2.0 |2.0 |4.0 |
* |2.5 |2.0 |6.25 |
* |4.0 |2.0 |16.0 |
* ----------------------------------
* }</pre>
*
* @param l The numeric column representing the base.
* @param r The value of the exponent.
* @return A column containing the result of raising {@code l} to the power of {@code r}.
* @since 1.15.0
*/
public static Column pow(Column l, Double r) {
return new Column(com.snowflake.snowpark.functions.pow(l.toScalaColumn(), r));
}

/**
* Returns a number (l) raised to the specified power (r).
*
* <p>Example:
*
* <pre>{@code
* DataFrame df = session.sql("select * from (values (0.5), (2), (2.5), (4)) as T(base)");
* df.select(col("base"), lit(2.0).as("exponent"), pow("base", 2.0).as("result")).show();
*
* ----------------------------------
* |"BASE" |"EXPONENT" |"RESULT" |
* ----------------------------------
* |0.5 |2.0 |0.25 |
* |2.0 |2.0 |4.0 |
* |2.5 |2.0 |6.25 |
* |4.0 |2.0 |16.0 |
* ----------------------------------
* }</pre>
*
* @param l The name of the numeric column representing the base.
* @param r The value of the exponent.
* @return A column containing the result of raising {@code l} to the power of {@code r}.
* @since 1.15.0
*/
public static Column pow(String l, Double r) {
return new Column(com.snowflake.snowpark.functions.pow(l, r));
}

/**
* Returns a number (l) raised to the specified power (r).
*
* <p>Example:
*
* <pre>{@code
* DataFrame df = session.sql("select * from (values (0.5), (2), (2.5), (4)) as T(exponent)");
* df.select(lit(2.0).as("base"), col("exponent"), pow(2.0, col("exponent")).as("result")).show();
*
* --------------------------------------------
* |"BASE" |"EXPONENT" |"RESULT" |
* --------------------------------------------
* |2.0 |0.5 |1.4142135623730951 |
* |2.0 |2.0 |4.0 |
* |2.0 |2.5 |5.656854249492381 |
* |2.0 |4.0 |16.0 |
* --------------------------------------------
* }</pre>
*
* @param l The value of the base.
* @param r The numeric column representing the exponent.
* @return A column containing the result of raising {@code l} to the power of {@code r}.
* @since 1.15.0
*/
public static Column pow(Double l, Column r) {
return new Column(com.snowflake.snowpark.functions.pow(l, r.toScalaColumn()));
}

/**
* Returns a number (l) raised to the specified power (r).
*
* <p>Example:
*
* <pre>{@code
* DataFrame df = session.sql("select * from (values (0.5), (2), (2.5), (4)) as T(exponent)");
* df.select(lit(2.0).as("base"), col("exponent"), pow(2.0, "exponent").as("result")).show();
*
* --------------------------------------------
* |"BASE" |"EXPONENT" |"RESULT" |
* --------------------------------------------
* |2.0 |0.5 |1.4142135623730951 |
* |2.0 |2.0 |4.0 |
* |2.0 |2.5 |5.656854249492381 |
* |2.0 |4.0 |16.0 |
* --------------------------------------------
* }</pre>
*
* @param l The value of the base.
* @param r The name of the numeric column representing the exponent.
* @return A column containing the result of raising {@code l} to the power of {@code r}.
* @since 1.15.0
*/
public static Column pow(Double l, String r) {
return new Column(com.snowflake.snowpark.functions.pow(l, r));
}

/**
* Rounds the numeric values of the given column {@code e} to the {@code scale} decimal places
* using the half away from zero rounding mode.
Expand Down
58 changes: 58 additions & 0 deletions src/main/java/com/snowflake/snowpark_java/Row.java
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,64 @@ public Row getObject(int index) {
return (Row) get(index);
}

/**
* Returns the value at the specified column index and casts it to the desired type {@code T}.
*
* <p>Example:
*
* <pre>{@code
* Row row = Row.create(1, "Alice", 95.5);
* row.getAs(0, Integer.class); // Returns 1 as an Int
* row.getAs(1, String.class); // Returns "Alice" as a String
* row.getAs(2, Double.class); // Returns 95.5 as a Double
* }</pre>
*
* @param index the zero-based column index within the row.
* @param clazz the {@code Class} object representing the type {@code T}.
* @param <T> the expected type of the value at the specified column index.
* @return the value at the specified column index cast to type {@code T}.
* @throws ClassCastException if the value at the given index cannot be cast to type {@code T}.
* @throws ArrayIndexOutOfBoundsException if the column index is out of bounds.
* @since 1.15.0
*/
@SuppressWarnings("unchecked")
public <T> T getAs(int index, Class<T> clazz)
throws ClassCastException, ArrayIndexOutOfBoundsException {
if (isNullAt(index)) {
return (T) get(index);
}

if (clazz == Byte.class) {
return (T) (Object) getByte(index);
}

if (clazz == Double.class) {
return (T) (Object) getDouble(index);
}

if (clazz == Float.class) {
return (T) (Object) getFloat(index);
}

if (clazz == Integer.class) {
return (T) (Object) getInt(index);
}

if (clazz == Long.class) {
return (T) (Object) getLong(index);
}

if (clazz == Short.class) {
return (T) (Object) getShort(index);
}

if (clazz == Variant.class) {
return (T) getVariant(index);
}

return (T) get(index);
}

/**
* Generates a string value to represent the content of this row.
*
Expand Down
40 changes: 35 additions & 5 deletions src/main/scala/com/snowflake/snowpark/Row.scala
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,39 @@ class Row protected (values: Array[Any]) extends Serializable {
getAs[Map[T, U]](index)
}

/**
* Returns the value at the specified column index and casts it to the desired type `T`.
*
* Example:
* {{{
* val row = Row(1, "Alice", 95.5)
* row.getAs[Int](0) // Returns 1 as an Int
* row.getAs[String](1) // Returns "Alice" as a String
* row.getAs[Double](2) // Returns 95.5 as a Double
* }}}
*
* @param index the zero-based column index within the row.
* @tparam T the expected type of the value at the specified column index.
* @return the value at the specified column index cast to type `T`.
* @throws ClassCastException if the value at the given index cannot be cast to type `T`.
* @throws ArrayIndexOutOfBoundsException if the column index is out of bounds.
* @group getter
* @since 1.15.0
*/
def getAs[T](index: Int)(implicit classTag: ClassTag[T]): T = {
classTag.runtimeClass match {
case _ if isNullAt(index) => get(index).asInstanceOf[T]
case c if c == classOf[Byte] => getByte(index).asInstanceOf[T]
case c if c == classOf[Double] => getDouble(index).asInstanceOf[T]
case c if c == classOf[Float] => getFloat(index).asInstanceOf[T]
case c if c == classOf[Int] => getInt(index).asInstanceOf[T]
case c if c == classOf[Long] => getLong(index).asInstanceOf[T]
case c if c == classOf[Short] => getShort(index).asInstanceOf[T]
case c if c == classOf[Variant] => getVariant(index).asInstanceOf[T]
case _ => get(index).asInstanceOf[T]
}
}

protected def convertValueToString(value: Any): String =
value match {
case null => "null"
Expand Down Expand Up @@ -400,10 +433,7 @@ class Row protected (values: Array[Any]) extends Serializable {
.map(convertValueToString)
.mkString("Row[", ",", "]")

private def getAs[T](index: Int): T = get(index).asInstanceOf[T]

private def getAnyValAs[T <: AnyVal](index: Int): T =
private def getAnyValAs[T <: AnyVal](index: Int)(implicit classTag: ClassTag[T]): T =
if (isNullAt(index)) throw new NullPointerException(s"Value at index $index is null")
else getAs[T](index)

else getAs[T](index)(classTag)
}
Loading

0 comments on commit 21344d7

Please sign in to comment.