diff --git a/pom.xml b/pom.xml
index b4283c76..0f01902c 100644
--- a/pom.xml
+++ b/pom.xml
@@ -4,7 +4,7 @@
   <modelVersion>4.0.0</modelVersion>
   <groupId>com.snowflake</groupId>
   <artifactId>snowpark</artifactId>
-  <version>1.9.0</version>
+  <version>1.9.0-coveo-2</version>
   <name>${project.artifactId}</name>
   <description>Snowflake's DataFrame API</description>
   <url>https://www.snowflake.com/</url>
@@ -443,66 +443,10 @@
-          <plugin>
-            <groupId>org.scoverage</groupId>
-            <artifactId>scoverage-maven-plugin</artifactId>
-            <version>${scoverage.plugin.version}</version>
-            <executions>
-              <execution>
-                <goals>
-                  <goal>check</goal>
-                </goals>
-                <phase>prepare-package</phase>
-              </execution>
-            </executions>
-          </plugin>
-          <plugin>
-            <groupId>org.apache.maven.plugins</groupId>
-            <artifactId>maven-gpg-plugin</artifactId>
-            <version>1.6</version>
-            <executions>
-              <execution>
-                <id>sign-deploy-artifacts</id>
-                <phase>package</phase>
-                <goals>
-                  <goal>sign</goal>
-                </goals>
-              </execution>
-            </executions>
-          </plugin>
-          <plugin>
-            <groupId>net.nicoulaj.maven.plugins</groupId>
-            <artifactId>checksum-maven-plugin</artifactId>
-            <version>1.10</version>
-            <executions>
-              <execution>
-                <phase>package</phase>
-                <goals>
-                  <goal>artifacts</goal>
-                </goals>
-              </execution>
-            </executions>
-            <configuration>
-              <algorithms>
-                <algorithm>SHA-256</algorithm>
-                <algorithm>md5</algorithm>
-              </algorithms>
-            </configuration>
-          </plugin>
-          <plugin>
-            <artifactId>maven-deploy-plugin</artifactId>
-            <configuration>
-              <skip>true</skip>
-            </configuration>
-          </plugin>
@@ -516,25 +460,6 @@
-          <plugin>
-            <groupId>org.apache.maven.plugins</groupId>
-            <artifactId>maven-assembly-plugin</artifactId>
-            <version>3.3.0</version>
-            <executions>
-              <execution>
-                <id>generate-tar-zip</id>
-                <phase>none</phase>
-              </execution>
-              <execution>
-                <id>with-dependencies</id>
-                <phase>none</phase>
-              </execution>
-              <execution>
-                <id>fat-test</id>
-                <phase>none</phase>
-              </execution>
-            </executions>
-          </plugin>
           <plugin>
             <artifactId>maven-jar-plugin</artifactId>
             <version>3.1.0</version>
@@ -551,52 +476,63 @@
-          <plugin>
-            <groupId>org.apache.maven.plugins</groupId>
-            <artifactId>maven-gpg-plugin</artifactId>
-            <version>1.6</version>
-            <executions>
-              <execution>
-                <id>sign-deploy-artifacts</id>
-                <phase>none</phase>
-              </execution>
-              <execution>
-                <id>sign-and-deploy-file</id>
-                <phase>deploy</phase>
-                <goals>
-                  <goal>sign-and-deploy-file</goal>
-                </goals>
-                <configuration>
-                  <file>target/${project.artifactId}-${project.version}.jar</file>
-                  <repositoryId>ossrh</repositoryId>
-                  <url>https://oss.sonatype.org/service/local/staging/deploy/maven2</url>
-                  <pomFile>pom.xml</pomFile>
-                  <javadoc>target/${project.artifactId}-${project.version}-javadoc.jar</javadoc>
-                  <keyname>${env.GPG_KEY_ID}</keyname>
-                  <passphrase>${env.GPG_KEY_PASSPHRASE}</passphrase>
-                </configuration>
-              </execution>
-            </executions>
-          </plugin>
         </plugins>
       </build>
     </profile>
   </profiles>
-  <reporting>
-    <plugins>
-      <plugin>
-        <groupId>org.scoverage</groupId>
-        <artifactId>scoverage-maven-plugin</artifactId>
-        <version>${scoverage.plugin.version}</version>
-        <reportSets>
-          <reportSet>
-            <reports>
-              <report>report-only</report>
-            </reports>
-          </reportSet>
-        </reportSets>
-      </plugin>
-    </plugins>
-  </reporting>
+  <repositories>
+    <repository>
+      <snapshots>
+        <enabled>false</enabled>
+      </snapshots>
+      <id>central</id>
+      <name>libs-release</name>
+      <url>https://maven.cloud.coveo.com/artifactory/libs-release</url>
+    </repository>
+    <repository>
+      <snapshots>
+        <enabled>false</enabled>
+      </snapshots>
+      <id>snapshots</id>
+      <name>libs-snapshot</name>
+      <url>https://maven.cloud.coveo.com/artifactory/libs-snapshot</url>
+    </repository>
+  </repositories>
+  <pluginRepositories>
+    <pluginRepository>
+      <snapshots>
+        <enabled>false</enabled>
+      </snapshots>
+      <id>central</id>
+      <name>plugins-release</name>
+      <url>https://maven.cloud.coveo.com/artifactory/plugins-release</url>
+    </pluginRepository>
+    <pluginRepository>
+      <snapshots>
+        <enabled>false</enabled>
+      </snapshots>
+      <id>snapshots</id>
+      <name>plugins-snapshot</name>
+      <url>https://maven.cloud.coveo.com/artifactory/plugins-snapshot</url>
+    </pluginRepository>
+  </pluginRepositories>
+  <distributionManagement>
+    <repository>
+      <id>CoveoCloud</id>
+      <name>CoveoCloud-releases</name>
+      <url>https://maven.cloud.coveo.com/artifactory/libs-release-local</url>
+    </repository>
+    <snapshotRepository>
+      <id>CoveoCloud</id>
+      <name>CoveoCloud-snapshots</name>
+      <url>https://maven.cloud.coveo.com/artifactory/libs-snapshot-local</url>
+    </snapshotRepository>
+    <site>
+      <id>${project.artifactId}</id>
+      <name>${project.artifactId}</name>
+      <url>${project.artifactId}/</url>
+    </site>
+  </distributionManagement>
 </project>
diff --git a/src/main/java/com/snowflake/snowpark_java/Functions.java b/src/main/java/com/snowflake/snowpark_java/Functions.java
index 74cc39a8..b762eb80 100644
--- a/src/main/java/com/snowflake/snowpark_java/Functions.java
+++ b/src/main/java/com/snowflake/snowpark_java/Functions.java
@@ -79,7 +79,7 @@ public static Column toScalar(DataFrame df) {
* @return The result column
*/
public static Column lit(Object literal) {
- return new Column(com.snowflake.snowpark.functions.lit(literal));
+ return new Column(com.snowflake.snowpark.functions.lit(JavaUtils.toScala(literal)));
}
/**
@@ -2365,14 +2365,13 @@ public static Column regexp_count(Column strExpr, Column pattern) {
*/
public static Column regexp_replace(Column strExpr, Column pattern) {
return new Column(
- com.snowflake.snowpark.functions.regexp_replace(
- strExpr.toScalaColumn(), pattern.toScalaColumn()));
+ com.snowflake.snowpark.functions.regexp_replace(
+ strExpr.toScalaColumn(), pattern.toScalaColumn()));
}
/**
- * Returns the subject with the specified pattern (or all occurrences of the pattern)
- * replaced by a replacement string. If no matches are found, returns the original
- * subject.
+ * Returns the subject with the specified pattern (or all occurrences of the pattern) replaced by
+ * a replacement string. If no matches are found, returns the original subject.
*
* @param strExpr The input string
* @param pattern The pattern
@@ -2382,8 +2381,8 @@ public static Column regexp_replace(Column strExpr, Column pattern) {
*/
public static Column regexp_replace(Column strExpr, Column pattern, Column replacement) {
return new Column(
- com.snowflake.snowpark.functions.regexp_replace(
- strExpr.toScalaColumn(), pattern.toScalaColumn(), replacement.toScalaColumn()));
+ com.snowflake.snowpark.functions.regexp_replace(
+ strExpr.toScalaColumn(), pattern.toScalaColumn(), replacement.toScalaColumn()));
}
/**
diff --git a/src/main/scala/com/snowflake/snowpark/functions.scala b/src/main/scala/com/snowflake/snowpark/functions.scala
index 49f38593..2aaf9051 100644
--- a/src/main/scala/com/snowflake/snowpark/functions.scala
+++ b/src/main/scala/com/snowflake/snowpark/functions.scala
@@ -1818,6 +1818,19 @@ object functions {
builtin("regexp_replace")(strExpr, pattern, replacement)
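+
+  /**
+   * Extracts all substrings in `strExpr` that match `pattern`, starting at `position` and
+   * from the given `occurrence` on, by delegating to Snowflake's REGEXP_EXTRACT_ALL built-in.
+   * `group_num_opt`, when provided, selects the regex capture group to extract.
+   */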
+ def regexp_extract_all(strExpr: Column,
+ pattern: Column,
+ position: Column = lit(1),
+ occurrence: Column = lit(1),
+ regex_parameters: Column = lit("c"),
+ group_num_opt: Option[Column] = None): Column =
+ builtin("regexp_extract_all")((Seq(
+ strExpr,
+ pattern,
+ position,
+ occurrence,
+ regex_parameters) ++ group_num_opt): _*)
+
/**
* Removes all occurrences of a specified strExpr,
* and optionally replaces them with replacement.
diff --git a/src/main/scala/com/snowflake/snowpark/internal/JavaUtils.scala b/src/main/scala/com/snowflake/snowpark/internal/JavaUtils.scala
index 08a92b6b..9f4b6784 100644
--- a/src/main/scala/com/snowflake/snowpark/internal/JavaUtils.scala
+++ b/src/main/scala/com/snowflake/snowpark/internal/JavaUtils.scala
@@ -414,4 +414,16 @@ object JavaUtils {
}
}
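+
+  /**
+   * Recursively converts Java Maps, Iterables and Iterators to their Scala equivalents,
+   * converting nested keys, values and elements as well; any other value passes through
+   * unchanged.
+   */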
+ def toScala(element: Any): Any = {
+ import collection.JavaConverters._
+ element match {
+ case map: java.util.Map[_, _] =>
+ mapAsScalaMap(map).map {
+ case (k, v) => toScala(k) -> toScala(v)
+ }.toMap
+ case iterable: java.lang.Iterable[_] => iterableAsScalaIterable(iterable).map(toScala)
+ case iterator: java.util.Iterator[_] => asScalaIterator(iterator).map(toScala)
+ case _ => element
+ }
+ }
}
diff --git a/src/main/scala/com/snowflake/snowpark/internal/analyzer/DataTypeMapper.scala b/src/main/scala/com/snowflake/snowpark/internal/analyzer/DataTypeMapper.scala
index 598dd166..c55b45ea 100644
--- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/DataTypeMapper.scala
+++ b/src/main/scala/com/snowflake/snowpark/internal/analyzer/DataTypeMapper.scala
@@ -1,20 +1,20 @@
package com.snowflake.snowpark.internal.analyzer
+
import com.snowflake.snowpark.internal.Utils
+import com.snowflake.snowpark.types._
+import net.snowflake.client.jdbc.internal.snowflake.common.core.SnowflakeDateTimeFormat
+import java.math.{BigDecimal => JBigDecimal}
import java.sql.{Date, Timestamp}
import java.util.TimeZone
-import java.math.{BigDecimal => JBigDecimal}
-
-import com.snowflake.snowpark.types._
-import com.snowflake.snowpark.types.convertToSFType
import javax.xml.bind.DatatypeConverter
-import net.snowflake.client.jdbc.internal.snowflake.common.core.SnowflakeDateTimeFormat
object DataTypeMapper {
// milliseconds per day
private val MILLIS_PER_DAY = 24 * 3600 * 1000L
// microseconds per millisecond
private val MICROS_PER_MILLIS = 1000L
+
private[analyzer] def stringToSql(str: String): String =
// Escapes all backslashes, single quotes and new line.
"'" + str
@@ -25,63 +25,78 @@ object DataTypeMapper {
/*
* Convert a value with DataType to a snowflake compatible sql
*/
- private[analyzer] def toSql(value: Any, dataType: Option[DataType]): String = {
- dataType match {
- case None => "NULL"
- case Some(dt) =>
- (value, dt) match {
- case (_, _: ArrayType | _: MapType | _: StructType | GeographyType) if value == null =>
- "NULL"
- case (_, IntegerType) if value == null => "NULL :: int"
- case (_, ShortType) if value == null => "NULL :: smallint"
- case (_, ByteType) if value == null => "NULL :: tinyint"
- case (_, LongType) if value == null => "NULL :: bigint"
- case (_, FloatType) if value == null => "NULL :: float"
- case (_, StringType) if value == null => "NULL :: string"
- case (_, DoubleType) if value == null => "NULL :: double"
- case (_, BooleanType) if value == null => "NULL :: boolean"
- case (_, BinaryType) if value == null => "NULL :: binary"
- case _ if value == null => "NULL"
- case (v: String, StringType) => stringToSql(v)
- case (v: Byte, ByteType) => v + s" :: tinyint"
- case (v: Short, ShortType) => v + s" :: smallint"
- case (v: Any, IntegerType) => v + s" :: int"
- case (v: Long, LongType) => v + s" :: bigint"
- case (v: Boolean, BooleanType) => s"$v :: boolean"
- // Float type doesn't have a suffix
- case (v: Float, FloatType) =>
- val castedValue = v match {
- case _ if v.isNaN => "'NaN'"
- case Float.PositiveInfinity => "'Infinity'"
- case Float.NegativeInfinity => "'-Infinity'"
- case _ => s"'$v'"
+ private[analyzer] def toSql(literal: TLiteral): String = {
+ literal match {
+ case Literal(value, dataType) =>
+ (value, dataType) match {
+ case (_, None) => "NULL"
+ case (value, Some(dt)) =>
+ (value, dt) match {
+ case (_, _: ArrayType | _: MapType | _: StructType | GeographyType)
+ if value == null =>
+ "NULL"
+ case (_, IntegerType) if value == null => "NULL :: int"
+ case (_, ShortType) if value == null => "NULL :: smallint"
+ case (_, ByteType) if value == null => "NULL :: tinyint"
+ case (_, LongType) if value == null => "NULL :: bigint"
+ case (_, FloatType) if value == null => "NULL :: float"
+ case (_, StringType) if value == null => "NULL :: string"
+ case (_, DoubleType) if value == null => "NULL :: double"
+ case (_, BooleanType) if value == null => "NULL :: boolean"
+ case (_, BinaryType) if value == null => "NULL :: binary"
+ case _ if value == null => "NULL"
+ case (v: String, StringType) => stringToSql(v)
+ case (v: Byte, ByteType) => v + s" :: tinyint"
+ case (v: Short, ShortType) => v + s" :: smallint"
+ case (v: Any, IntegerType) => v + s" :: int"
+ case (v: Long, LongType) => v + s" :: bigint"
+ case (v: Boolean, BooleanType) => s"$v :: boolean"
+ // Float type doesn't have a suffix
+ case (v: Float, FloatType) =>
+ val castedValue = v match {
+ case _ if v.isNaN => "'NaN'"
+ case Float.PositiveInfinity => "'Infinity'"
+ case Float.NegativeInfinity => "'-Infinity'"
+ case _ => s"'$v'"
+ }
+ s"$castedValue :: FLOAT"
+ case (v: Double, DoubleType) =>
+ v match {
+ case _ if v.isNaN => "'NaN'"
+ case Double.PositiveInfinity => "'Infinity'"
+ case Double.NegativeInfinity => "'-Infinity'"
+ case _ => v + "::DOUBLE"
+ }
+ case (v: BigDecimal, t: DecimalType) => v + s" :: ${number(t.precision, t.scale)}"
+ case (v: JBigDecimal, t: DecimalType) => v + s" :: ${number(t.precision, t.scale)}"
+ case (v: Int, DateType) =>
+ s"DATE '${SnowflakeDateTimeFormat
+ .fromSqlFormat(Utils.DateInputFormat)
+ .format(new Date(v * MILLIS_PER_DAY), TimeZone.getTimeZone("GMT"))}'"
+ case (v: Long, TimestampType) =>
+ s"TIMESTAMP '${SnowflakeDateTimeFormat
+ .fromSqlFormat(Utils.TimestampInputFormat)
+ .format(new Timestamp(v / MICROS_PER_MILLIS), TimeZone.getDefault, 3)}'"
+ case _ =>
+ throw new UnsupportedOperationException(
+ s"Unsupported datatype by ToSql: ${value.getClass.getName} => $dataType")
}
- s"$castedValue :: FLOAT"
- case (v: Double, DoubleType) =>
- v match {
- case _ if v.isNaN => "'NaN'"
- case Double.PositiveInfinity => "'Infinity'"
- case Double.NegativeInfinity => "'-Infinity'"
- case _ => v + "::DOUBLE"
- }
- case (v: BigDecimal, t: DecimalType) => v + s" :: ${number(t.precision, t.scale)}"
- case (v: JBigDecimal, t: DecimalType) => v + s" :: ${number(t.precision, t.scale)}"
- case (v: Int, DateType) =>
- s"DATE '${SnowflakeDateTimeFormat
- .fromSqlFormat(Utils.DateInputFormat)
- .format(new Date(v * MILLIS_PER_DAY), TimeZone.getTimeZone("GMT"))}'"
- case (v: Long, TimestampType) =>
- s"TIMESTAMP '${SnowflakeDateTimeFormat
- .fromSqlFormat(Utils.TimestampInputFormat)
- .format(new Timestamp(v / MICROS_PER_MILLIS), TimeZone.getDefault, 3)}'"
- case (v: Array[Byte], BinaryType) =>
- s"'${DatatypeConverter.printHexBinary(v)}' :: binary"
- case _ =>
- throw new UnsupportedOperationException(
- s"Unsupported datatype by ToSql: ${value.getClass.getName} => $dataType")
}
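+      // A Seq[Byte] literal is rendered as a BINARY constant; any other sequence becomes
+      // an ARRAY_CONSTRUCT over its recursively rendered element literals.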
+ case arrayLiteral: ArrayLiteral =>
+ if (arrayLiteral.dataTypeOption == Some(BinaryType)) {
+ val bytes = arrayLiteral.value.asInstanceOf[Seq[Byte]].toArray
+ s"'${DatatypeConverter.printHexBinary(bytes)}' :: binary"
+ } else {
+ "ARRAY_CONSTRUCT" + arrayLiteral.elementsLiterals.map(toSql).mkString("(", ", ", ")")
+ }
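+      // Map literals are rendered as OBJECT_CONSTRUCT(key1, value1, key2, value2, ...),
+      // with keys and values rendered recursively.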
+ case mapLiteral: MapLiteral =>
+ "OBJECT_CONSTRUCT" + mapLiteral.entriesLiterals
+ .flatMap {
+ case (keyLiteral, valueLiteral) =>
+ Seq(toSql(keyLiteral), toSql(valueLiteral))
+ }
+ .mkString("(", ", ", ")")
}
-
}
private[analyzer] def schemaExpression(dataType: DataType, isNullable: Boolean): String =
diff --git a/src/main/scala/com/snowflake/snowpark/internal/analyzer/Literal.scala b/src/main/scala/com/snowflake/snowpark/internal/analyzer/Literal.scala
index 69fb3eda..5ecea34b 100644
--- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/Literal.scala
+++ b/src/main/scala/com/snowflake/snowpark/internal/analyzer/Literal.scala
@@ -2,12 +2,11 @@ package com.snowflake.snowpark.internal.analyzer
import com.snowflake.snowpark.internal.ErrorMessage
import com.snowflake.snowpark.types._
+
import java.math.{BigDecimal => JavaBigDecimal}
import java.sql.{Date, Timestamp}
import java.time.{Instant, LocalDate}
-import scala.math.BigDecimal
-
private[snowpark] object Literal {
// Snowflake max precision for decimal is 38
private lazy val bigDecimalRoundContext = new java.math.MathContext(DecimalType.MAX_PRECISION)
@@ -16,7 +15,7 @@ private[snowpark] object Literal {
decimal.round(bigDecimalRoundContext)
}
- def apply(v: Any): Literal = v match {
+ def apply(v: Any): TLiteral = v match {
case i: Int => Literal(i, Option(IntegerType))
case l: Long => Literal(l, Option(LongType))
case d: Double => Literal(d, Option(DoubleType))
@@ -36,7 +35,8 @@ private[snowpark] object Literal {
case t: Timestamp => Literal(DateTimeUtils.javaTimestampToMicros(t), Option(TimestampType))
case ld: LocalDate => Literal(DateTimeUtils.localDateToDays(ld), Option(DateType))
case d: Date => Literal(DateTimeUtils.javaDateToDays(d), Option(DateType))
- case a: Array[Byte] => Literal(a, Option(BinaryType))
+ case s: Seq[Any] => ArrayLiteral(s)
+ case m: Map[Any, Any] => MapLiteral(m)
case null => Literal(null, None)
case v: Literal => v
case _ =>
@@ -45,10 +45,49 @@ private[snowpark] object Literal {
}
-private[snowpark] case class Literal private (value: Any, dataTypeOption: Option[DataType])
- extends Expression {
+private[snowpark] trait TLiteral extends Expression {
+ def value: Any
+ def dataTypeOption: Option[DataType]
+
override def children: Seq[Expression] = Seq.empty
override protected def createAnalyzedExpression(analyzedChildren: Seq[Expression]): Expression =
this
}
+
+private[snowpark] case class Literal(value: Any, dataTypeOption: Option[DataType])
+ extends TLiteral
+
+private[snowpark] case class ArrayLiteral(value: Seq[Any]) extends TLiteral {
+ val elementsLiterals: Seq[TLiteral] = value.map(Literal(_))
+ val dataTypeOption = inferArrayType
+
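+  // Infers the element type: an all-Byte sequence is treated as BINARY, a single element
+  // type yields ArrayType(dt), and mixed element types widen to ArrayType(VariantType).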
+ private[analyzer] def inferArrayType(): Option[DataType] = {
+ elementsLiterals.flatMap(_.dataTypeOption).distinct match {
+ case Seq() => None
+ case Seq(ByteType) => Some(BinaryType)
+ case Seq(dt) => Some(ArrayType(dt))
+ case Seq(_, _*) => Some(ArrayType(VariantType))
+ }
+ }
+}
+
+private[snowpark] case class MapLiteral(value: Map[Any, Any]) extends TLiteral {
+ val entriesLiterals = value.map { case (k, v) => Literal(k) -> Literal(v) }
+ val dataTypeOption = inferMapType
+
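+  // Infers the map type: only String keys are supported (anything else raises
+  // PLAN_CANNOT_CREATE_LITERAL); mixed value types widen to MapType(StringType, VariantType).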
+ private[analyzer] def inferMapType(): Option[MapType] = {
+ entriesLiterals.keys.flatMap(_.dataTypeOption).toSeq.distinct match {
+ case Seq() => None
+ case Seq(StringType) =>
+ val valuesTypes = entriesLiterals.values.flatMap(_.dataTypeOption).toSeq.distinct
+ valuesTypes match {
+ case Seq() => None
+ case Seq(dt) => Some(MapType(StringType, dt))
+ case Seq(_, _*) => Some(MapType(StringType, VariantType))
+ }
+ case _ =>
+ throw ErrorMessage.PLAN_CANNOT_CREATE_LITERAL(value.getClass.getCanonicalName, s"$value")
+ }
+ }
+}
diff --git a/src/main/scala/com/snowflake/snowpark/internal/analyzer/SqlGenerator.scala b/src/main/scala/com/snowflake/snowpark/internal/analyzer/SqlGenerator.scala
index a7a5f655..058d00a4 100644
--- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/SqlGenerator.scala
+++ b/src/main/scala/com/snowflake/snowpark/internal/analyzer/SqlGenerator.scala
@@ -203,8 +203,8 @@ private object SqlGenerator extends Logging {
case UnspecifiedFrame => ""
case SpecialFrameBoundaryExtractor(str) => str
- case Literal(value, dataType) =>
- DataTypeMapper.toSql(value, dataType)
+ case l: TLiteral =>
+ DataTypeMapper.toSql(l)
case attr: Attribute => quoteName(attr.name)
// unresolved expression
case UnresolvedAttribute(name) => name
diff --git a/src/main/scala/com/snowflake/snowpark/internal/analyzer/package.scala b/src/main/scala/com/snowflake/snowpark/internal/analyzer/package.scala
index a6af91aa..6c0cc960 100644
--- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/package.scala
+++ b/src/main/scala/com/snowflake/snowpark/internal/analyzer/package.scala
@@ -3,7 +3,7 @@ package com.snowflake.snowpark.internal
import com.snowflake.snowpark.FileOperationCommand._
import com.snowflake.snowpark.Row
import com.snowflake.snowpark.internal.Utils.{TempObjectType, randomNameForTempObject}
-import com.snowflake.snowpark.types.{DataType, convertToSFType}
+import com.snowflake.snowpark.types.{ArrayType, DataType, MapType, convertToSFType}
package object analyzer {
// constant string
@@ -446,7 +446,9 @@ package object analyzer {
val types = output.map(_.dataType)
val rows = data.map { row =>
val cells = row.toSeq.zip(types).map {
- case (v, dType) => DataTypeMapper.toSql(v, Option(dType))
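+        // Wrap Seq and Map cells in the new literal types so toSql emits
+        // ARRAY_CONSTRUCT / OBJECT_CONSTRUCT instead of a plain cast.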
+ case (v: Seq[Any], _: ArrayType) => DataTypeMapper.toSql(ArrayLiteral(v))
+ case (v: Map[Any, Any], _: MapType) => DataTypeMapper.toSql(MapLiteral(v))
+ case (v, dType) => DataTypeMapper.toSql(Literal(v, Option(dType)))
}
cells.mkString(_LeftParenthesis, _Comma, _RightParenthesis)
}
@@ -584,7 +586,7 @@ package object analyzer {
private[analyzer] def createTableStatement(
tableName: String,
schema: String,
- replace: Boolean = false,
+ replace: Boolean = true,
error: Boolean = true,
tempType: TempType = TempType.Permanent): String =
_Create + (if (replace) _Or + _Replace else _EmptyString) + tempType + _Table + tableName +
diff --git a/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java
index f74dc440..ce76dc79 100644
--- a/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java
+++ b/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java
@@ -1,9 +1,17 @@
package com.snowflake.snowpark_test;
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertThrows;
+
import com.snowflake.snowpark_java.*;
import java.sql.Date;
import java.sql.Time;
import java.sql.Timestamp;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
import org.junit.Test;
public class JavaFunctionSuite extends TestBase {
@@ -18,6 +26,64 @@ public void toScalar() {
checkAnswer(df1.select(Functions.col("c1"), Functions.toScalar(df2)), expected, false);
}
+ @Test
+ public void lit() {
+ DataFrame df = getSession().sql("select * from values (1),(2),(3) as T(a)");
+
+ // Empty array is supported
+ Row[] expectedEmptyArray = new Row[3];
+ Arrays.fill(expectedEmptyArray, Row.create("[]"));
+ checkAnswer(df.select(Functions.lit(Collections.EMPTY_LIST)), expectedEmptyArray, false);
+
+ // Empty map is supported
+ Row[] expectedEmptyMap = new Row[3];
+ Arrays.fill(expectedEmptyMap, Row.create("{}"));
+ checkAnswer(df.select(Functions.lit(Collections.EMPTY_MAP)), expectedEmptyMap, false);
+
+ // Array with only bytes should be considered Binary
+ Row[] expectedBinary = new Row[3];
+ Arrays.fill(expectedBinary, Row.create(new byte[] {(byte) 1, (byte) 2, (byte) 3}));
+
+ DataFrame actualBinary = df.select(Functions.lit(List.of((byte) 1, (byte) 2, (byte) 3)));
+
+ checkAnswer(actualBinary, expectedBinary);
+
+    // Array and Map result types are not supported; the results always come back as
+    // Strings. Hence, we test by comparing the result Strings.
+    Function<Row[], Object[]> rowsToString =
+        (Row[] rows) ->
+            Arrays.stream(rows).map((Row row) -> row.getString(0).replaceAll("\n| ", "")).toArray();
+
+ // Array with different types of elements
+ String[] expectedArrays = new String[3];
+ Arrays.fill(expectedArrays, "[1,\"3\",[\"2023-08-25\"]]");
+
+ Row[] actualArraysRows =
+ df.select(Functions.lit(List.of(1, "3", List.of(Date.valueOf("2023-08-25"))))).collect();
+ Object[] actualArrays = rowsToString.apply(actualArraysRows);
+
+    assertArrayEquals(expectedArrays, actualArrays);
+
+    // One or more map keys are not of the String type; an exception should be thrown.
+    assertThrows(
+        com.snowflake.snowpark.SnowparkClientException.class,
+        () -> df.select(Functions.lit(Map.of("1", 1, 2, 2))));
+
+ // Map with different type of elements
+ String[] expectedMaps = new String[3];
+ Arrays.fill(expectedMaps, "{\"key1\":{\"nestedKey\":42},\"key2\":\"2023-08-25\"}");
+
+ Row[] actualMapsRows =
+ df.select(
+ Functions.lit(
+ Map.of(
+ "key1", Map.of("nestedKey", 42),
+ "key2", Date.valueOf("2023-08-25"))))
+ .collect();
+ Object[] actualMaps = rowsToString.apply(actualMapsRows);
+
+    assertArrayEquals(expectedMaps, actualMaps);
+ }
+
@Test
public void sqlText() {
DataFrame df =
@@ -1457,7 +1523,7 @@ public void regexp_replace() {
Column replacement = Functions.lit("ch");
Row[] expected1 = {Row.create("cht"), Row.create("chg"), Row.create("chuse")};
checkAnswer(
- df.select(Functions.regexp_replace(df.col("a"), pattern, replacement)), expected1, false);
+ df.select(Functions.regexp_replace(df.col("a"), pattern, replacement)), expected1, false);
}
@Test