From fc756a530a7cfa9dde2fb44934ec3c4ec22fa903 Mon Sep 17 00:00:00 2001
From: Jean-Francis Roy
Date: Wed, 13 Dec 2023 10:17:26 -0500
Subject: [PATCH 1/4] Apply Jonathan Bergeron's PR #50 from GitHub's
 snowpark-java-scala to the Coveo fork. SEARCHREL-547

---
 pom.xml                                            |   2 +-
 .../snowflake/snowpark_java/Functions.java         |   2 +-
 .../snowpark/internal/JavaUtils.scala              |  11 ++
 .../internal/analyzer/DataTypeMapper.scala         | 136 ++++++++++--------
 .../snowpark/internal/analyzer/Literal.scala       |  50 ++++++-
 .../internal/analyzer/SqlGenerator.scala           |   4 +-
 .../snowpark/internal/analyzer/package.scala       |   6 +-
 .../snowpark_test/JavaFunctionSuite.java           |  72 ++++++++++
 8 files changed, 210 insertions(+), 73 deletions(-)

diff --git a/pom.xml b/pom.xml
index b4283c76..57f273ba 100644
--- a/pom.xml
+++ b/pom.xml
@@ -4,7 +4,7 @@
   <modelVersion>4.0.0</modelVersion>
   <groupId>com.snowflake</groupId>
   <artifactId>snowpark</artifactId>
-  <version>1.9.0</version>
+  <version>1.9.0-coveo-1</version>
   <name>${project.artifactId}</name>
   <description>Snowflake's DataFrame API</description>
   <url>https://www.snowflake.com/</url>
diff --git a/src/main/java/com/snowflake/snowpark_java/Functions.java b/src/main/java/com/snowflake/snowpark_java/Functions.java
index 74cc39a8..06477bb0 100644
--- a/src/main/java/com/snowflake/snowpark_java/Functions.java
+++ b/src/main/java/com/snowflake/snowpark_java/Functions.java
@@ -79,7 +79,7 @@ public static Column toScalar(DataFrame df) {
    * @return The result column
    */
   public static Column lit(Object literal) {
-    return new Column(com.snowflake.snowpark.functions.lit(literal));
+    return new Column(com.snowflake.snowpark.functions.lit(JavaUtils.toScala(literal)));
   }
 
   /**
diff --git a/src/main/scala/com/snowflake/snowpark/internal/JavaUtils.scala b/src/main/scala/com/snowflake/snowpark/internal/JavaUtils.scala
index 08a92b6b..6d817507 100644
--- a/src/main/scala/com/snowflake/snowpark/internal/JavaUtils.scala
+++ b/src/main/scala/com/snowflake/snowpark/internal/JavaUtils.scala
@@ -414,4 +414,15 @@ object JavaUtils {
     }
   }
 
+  def toScala(element: Any): Any = {
+    import collection.JavaConverters._
+    element match {
+      case map: java.util.Map[_, _] => mapAsScalaMap(map).map {
+        case (k, v) => toScala(k) -> toScala(v)
+      }.toMap
+      case iterable: java.lang.Iterable[_] => iterableAsScalaIterable(iterable).map(toScala)
+      case iterator: java.util.Iterator[_] => asScalaIterator(iterator).map(toScala)
+      case _ => element
+    }
+  }
 }
diff --git a/src/main/scala/com/snowflake/snowpark/internal/analyzer/DataTypeMapper.scala b/src/main/scala/com/snowflake/snowpark/internal/analyzer/DataTypeMapper.scala
index 598dd166..1993fd61 100644
--- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/DataTypeMapper.scala
+++ b/src/main/scala/com/snowflake/snowpark/internal/analyzer/DataTypeMapper.scala
@@ -1,22 +1,22 @@
 package com.snowflake.snowpark.internal.analyzer
+
 import com.snowflake.snowpark.internal.Utils
+import com.snowflake.snowpark.types._
+import net.snowflake.client.jdbc.internal.snowflake.common.core.SnowflakeDateTimeFormat
 
+import java.math.{BigDecimal => JBigDecimal}
 import java.sql.{Date, Timestamp}
 import java.util.TimeZone
-import java.math.{BigDecimal => JBigDecimal}
-
-import com.snowflake.snowpark.types._
-import com.snowflake.snowpark.types.convertToSFType
 import javax.xml.bind.DatatypeConverter
-import net.snowflake.client.jdbc.internal.snowflake.common.core.SnowflakeDateTimeFormat
 
 object DataTypeMapper {
   // milliseconds per day
   private val MILLIS_PER_DAY = 24 * 3600 * 1000L
   // microseconds per millisecond
   private val MICROS_PER_MILLIS = 1000L
+
   private[analyzer] def stringToSql(str: String): String =
-    // Escapes all backslashes, single quotes and new line.
+ // Escapes all backslashes, single quotes and new line. "'" + str .replaceAll("\\\\", "\\\\\\\\") .replaceAll("'", "''") @@ -25,63 +25,77 @@ object DataTypeMapper { /* * Convert a value with DataType to a snowflake compatible sql */ - private[analyzer] def toSql(value: Any, dataType: Option[DataType]): String = { - dataType match { - case None => "NULL" - case Some(dt) => - (value, dt) match { - case (_, _: ArrayType | _: MapType | _: StructType | GeographyType) if value == null => - "NULL" - case (_, IntegerType) if value == null => "NULL :: int" - case (_, ShortType) if value == null => "NULL :: smallint" - case (_, ByteType) if value == null => "NULL :: tinyint" - case (_, LongType) if value == null => "NULL :: bigint" - case (_, FloatType) if value == null => "NULL :: float" - case (_, StringType) if value == null => "NULL :: string" - case (_, DoubleType) if value == null => "NULL :: double" - case (_, BooleanType) if value == null => "NULL :: boolean" - case (_, BinaryType) if value == null => "NULL :: binary" - case _ if value == null => "NULL" - case (v: String, StringType) => stringToSql(v) - case (v: Byte, ByteType) => v + s" :: tinyint" - case (v: Short, ShortType) => v + s" :: smallint" - case (v: Any, IntegerType) => v + s" :: int" - case (v: Long, LongType) => v + s" :: bigint" - case (v: Boolean, BooleanType) => s"$v :: boolean" - // Float type doesn't have a suffix - case (v: Float, FloatType) => - val castedValue = v match { - case _ if v.isNaN => "'NaN'" - case Float.PositiveInfinity => "'Infinity'" - case Float.NegativeInfinity => "'-Infinity'" - case _ => s"'$v'" - } - s"$castedValue :: FLOAT" - case (v: Double, DoubleType) => - v match { - case _ if v.isNaN => "'NaN'" - case Double.PositiveInfinity => "'Infinity'" - case Double.NegativeInfinity => "'-Infinity'" - case _ => v + "::DOUBLE" - } - case (v: BigDecimal, t: DecimalType) => v + s" :: ${number(t.precision, t.scale)}" - case (v: JBigDecimal, t: DecimalType) => v + s" :: ${number(t.precision, t.scale)}" - case (v: Int, DateType) => - s"DATE '${SnowflakeDateTimeFormat - .fromSqlFormat(Utils.DateInputFormat) - .format(new Date(v * MILLIS_PER_DAY), TimeZone.getTimeZone("GMT"))}'" - case (v: Long, TimestampType) => - s"TIMESTAMP '${SnowflakeDateTimeFormat - .fromSqlFormat(Utils.TimestampInputFormat) - .format(new Timestamp(v / MICROS_PER_MILLIS), TimeZone.getDefault, 3)}'" - case (v: Array[Byte], BinaryType) => - s"'${DatatypeConverter.printHexBinary(v)}' :: binary" - case _ => - throw new UnsupportedOperationException( - s"Unsupported datatype by ToSql: ${value.getClass.getName} => $dataType") + private[analyzer] def toSql(literal: TLiteral): String = { + literal match { + case Literal(value, dataType) => (value, dataType) match { + case (_, None) => "NULL" + case (value, Some(dt)) => + (value, dt) match { + case (_, _: ArrayType | _: MapType | _: StructType | GeographyType) if value == null => + "NULL" + case (_, IntegerType) if value == null => "NULL :: int" + case (_, ShortType) if value == null => "NULL :: smallint" + case (_, ByteType) if value == null => "NULL :: tinyint" + case (_, LongType) if value == null => "NULL :: bigint" + case (_, FloatType) if value == null => "NULL :: float" + case (_, StringType) if value == null => "NULL :: string" + case (_, DoubleType) if value == null => "NULL :: double" + case (_, BooleanType) if value == null => "NULL :: boolean" + case (_, BinaryType) if value == null => "NULL :: binary" + case _ if value == null => "NULL" + case (v: String, StringType) => stringToSql(v) + 
case (v: Byte, ByteType) => v + s" :: tinyint" + case (v: Short, ShortType) => v + s" :: smallint" + case (v: Any, IntegerType) => v + s" :: int" + case (v: Long, LongType) => v + s" :: bigint" + case (v: Boolean, BooleanType) => s"$v :: boolean" + // Float type doesn't have a suffix + case (v: Float, FloatType) => + val castedValue = v match { + case _ if v.isNaN => "'NaN'" + case Float.PositiveInfinity => "'Infinity'" + case Float.NegativeInfinity => "'-Infinity'" + case _ => s"'$v'" + } + s"$castedValue :: FLOAT" + case (v: Double, DoubleType) => + v match { + case _ if v.isNaN => "'NaN'" + case Double.PositiveInfinity => "'Infinity'" + case Double.NegativeInfinity => "'-Infinity'" + case _ => v + "::DOUBLE" + } + case (v: BigDecimal, t: DecimalType) => v + s" :: ${number(t.precision, t.scale)}" + case (v: JBigDecimal, t: DecimalType) => v + s" :: ${number(t.precision, t.scale)}" + case (v: Int, DateType) => + s"DATE '${ + SnowflakeDateTimeFormat + .fromSqlFormat(Utils.DateInputFormat) + .format(new Date(v * MILLIS_PER_DAY), TimeZone.getTimeZone("GMT")) + }'" + case (v: Long, TimestampType) => + s"TIMESTAMP '${ + SnowflakeDateTimeFormat + .fromSqlFormat(Utils.TimestampInputFormat) + .format(new Timestamp(v / MICROS_PER_MILLIS), TimeZone.getDefault, 3) + }'" + case _ => + throw new UnsupportedOperationException( + s"Unsupported datatype by ToSql: ${value.getClass.getName} => $dataType") + } + } + case arrayLiteral: ArrayLiteral => + if (arrayLiteral.dataTypeOption == Some(BinaryType)) { + val bytes = arrayLiteral.value.asInstanceOf[Seq[Byte]].toArray + s"'${DatatypeConverter.printHexBinary(bytes)}' :: binary" + } else { + "ARRAY_CONSTRUCT" + arrayLiteral.elementsLiterals.map(toSql).mkString("(", ", ", ")") } + case mapLiteral: MapLiteral => + "OBJECT_CONSTRUCT" + mapLiteral.entriesLiterals.flatMap { case (keyLiteral, valueLiteral) => + Seq(toSql(keyLiteral), toSql(valueLiteral)) + }.mkString("(", ", ", ")") } - } private[analyzer] def schemaExpression(dataType: DataType, isNullable: Boolean): String = diff --git a/src/main/scala/com/snowflake/snowpark/internal/analyzer/Literal.scala b/src/main/scala/com/snowflake/snowpark/internal/analyzer/Literal.scala index 69fb3eda..86804508 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/Literal.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/analyzer/Literal.scala @@ -2,12 +2,11 @@ package com.snowflake.snowpark.internal.analyzer import com.snowflake.snowpark.internal.ErrorMessage import com.snowflake.snowpark.types._ + import java.math.{BigDecimal => JavaBigDecimal} import java.sql.{Date, Timestamp} import java.time.{Instant, LocalDate} -import scala.math.BigDecimal - private[snowpark] object Literal { // Snowflake max precision for decimal is 38 private lazy val bigDecimalRoundContext = new java.math.MathContext(DecimalType.MAX_PRECISION) @@ -16,7 +15,7 @@ private[snowpark] object Literal { decimal.round(bigDecimalRoundContext) } - def apply(v: Any): Literal = v match { + def apply(v: Any): TLiteral = v match { case i: Int => Literal(i, Option(IntegerType)) case l: Long => Literal(l, Option(LongType)) case d: Double => Literal(d, Option(DoubleType)) @@ -36,7 +35,8 @@ private[snowpark] object Literal { case t: Timestamp => Literal(DateTimeUtils.javaTimestampToMicros(t), Option(TimestampType)) case ld: LocalDate => Literal(DateTimeUtils.localDateToDays(ld), Option(DateType)) case d: Date => Literal(DateTimeUtils.javaDateToDays(d), Option(DateType)) - case a: Array[Byte] => Literal(a, Option(BinaryType)) + case s: 
Seq[Any] => ArrayLiteral(s) + case m: Map[Any, Any] => MapLiteral(m) case null => Literal(null, None) case v: Literal => v case _ => @@ -45,10 +45,48 @@ private[snowpark] object Literal { } -private[snowpark] case class Literal private (value: Any, dataTypeOption: Option[DataType]) - extends Expression { +private[snowpark] trait TLiteral extends Expression { + def value: Any + def dataTypeOption: Option[DataType] + override def children: Seq[Expression] = Seq.empty override protected def createAnalyzedExpression(analyzedChildren: Seq[Expression]): Expression = this } + +private[snowpark] case class Literal (value: Any, dataTypeOption: Option[DataType]) extends TLiteral + +private[snowpark] case class ArrayLiteral(value: Seq[Any]) extends TLiteral { + val elementsLiterals: Seq[TLiteral] = value.map(Literal(_)) + val dataTypeOption = inferArrayType + + private[analyzer] def inferArrayType(): Option[DataType] = { + elementsLiterals.flatMap(_.dataTypeOption).distinct match { + case Seq() => None + case Seq(ByteType) => Some(BinaryType) + case Seq(dt) => Some(ArrayType(dt)) + case Seq(_, _*) => Some(ArrayType(VariantType)) + } + } +} + +private[snowpark] case class MapLiteral(value: Map[Any, Any]) extends TLiteral { + val entriesLiterals = value.map { case (k, v) => Literal(k) -> Literal(v) } + val dataTypeOption = inferMapType + + private[analyzer] def inferMapType(): Option[MapType] = { + entriesLiterals.keys.flatMap(_.dataTypeOption).toSeq.distinct match { + case Seq() => None + case Seq(StringType) => + val valuesTypes = entriesLiterals.values.flatMap(_.dataTypeOption).toSeq.distinct + valuesTypes match { + case Seq() => None + case Seq(dt) => Some(MapType(StringType, dt)) + case Seq(_, _*) => Some(MapType(StringType, VariantType)) + } + case _ => + throw ErrorMessage.PLAN_CANNOT_CREATE_LITERAL(value.getClass.getCanonicalName, s"$value") + } + } +} diff --git a/src/main/scala/com/snowflake/snowpark/internal/analyzer/SqlGenerator.scala b/src/main/scala/com/snowflake/snowpark/internal/analyzer/SqlGenerator.scala index a7a5f655..058d00a4 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/SqlGenerator.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/analyzer/SqlGenerator.scala @@ -203,8 +203,8 @@ private object SqlGenerator extends Logging { case UnspecifiedFrame => "" case SpecialFrameBoundaryExtractor(str) => str - case Literal(value, dataType) => - DataTypeMapper.toSql(value, dataType) + case l: TLiteral => + DataTypeMapper.toSql(l) case attr: Attribute => quoteName(attr.name) // unresolved expression case UnresolvedAttribute(name) => name diff --git a/src/main/scala/com/snowflake/snowpark/internal/analyzer/package.scala b/src/main/scala/com/snowflake/snowpark/internal/analyzer/package.scala index a6af91aa..ca7edc41 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/package.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/analyzer/package.scala @@ -3,7 +3,7 @@ package com.snowflake.snowpark.internal import com.snowflake.snowpark.FileOperationCommand._ import com.snowflake.snowpark.Row import com.snowflake.snowpark.internal.Utils.{TempObjectType, randomNameForTempObject} -import com.snowflake.snowpark.types.{DataType, convertToSFType} +import com.snowflake.snowpark.types.{ArrayType, DataType, MapType, convertToSFType} package object analyzer { // constant string @@ -446,7 +446,9 @@ package object analyzer { val types = output.map(_.dataType) val rows = data.map { row => val cells = row.toSeq.zip(types).map { - case (v, dType) => 
DataTypeMapper.toSql(v, Option(dType)) + case (v: Seq[Any], _: ArrayType) => DataTypeMapper.toSql(ArrayLiteral(v)) + case (v: Map[Any, Any], _: MapType) => DataTypeMapper.toSql(MapLiteral(v)) + case (v, dType) => DataTypeMapper.toSql(Literal(v, Option(dType))) } cells.mkString(_LeftParenthesis, _Comma, _RightParenthesis) } diff --git a/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java index f74dc440..6af340dc 100644 --- a/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java +++ b/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java @@ -1,11 +1,23 @@ package com.snowflake.snowpark_test; +import com.snowflake.snowpark.internal.JavaUtils; +import com.snowflake.snowpark.internal.analyzer.Literal; import com.snowflake.snowpark_java.*; import java.sql.Date; import java.sql.Time; import java.sql.Timestamp; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.function.Function; + +import jdk.jshell.spi.ExecutionControl; import org.junit.Test; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; + public class JavaFunctionSuite extends TestBase { @Test @@ -17,6 +29,66 @@ public void toScalar() { checkAnswer(df1.select(Functions.col("c1"), Functions.col(df2)), expected, false); checkAnswer(df1.select(Functions.col("c1"), Functions.toScalar(df2)), expected, false); } + + @Test + public void lit() { + DataFrame df = getSession().sql("select * from values (1),(2),(3) as T(a)"); + + // Empty array is supported + Row[] expectedEmptyArray = new Row[3]; + Arrays.fill(expectedEmptyArray, Row.create("[]")); + checkAnswer(df.select(Functions.lit(Collections.EMPTY_LIST)), expectedEmptyArray, false); + + // Empty map is supported + Row[] expectedEmptyMap = new Row[3]; + Arrays.fill(expectedEmptyMap, Row.create("{}")); + checkAnswer(df.select(Functions.lit(Collections.EMPTY_MAP)), expectedEmptyMap, false); + + // Array with only bytes should be considered Binary + Row[] expectedBinary = new Row[3]; + Arrays.fill(expectedBinary, Row.create(new byte[]{(byte) 1, (byte) 2, (byte) 3})); + + DataFrame actualBinary = df.select(Functions.lit(List.of((byte) 1, (byte) 2, (byte) 3))); + + checkAnswer(actualBinary, expectedBinary); + + // Array and Map results type are not supported, they are instead always converted to String. + // Hence, we need to test by comparing results Strings. + Function rowsToString = (Row[] rows) -> Arrays.stream(rows) + .map((Row row) -> row.getString(0).replaceAll("\n| ", "")) + .toArray(); + + // Array with different types of elements + String[] expectedArrays = new String[3]; + Arrays.fill(expectedArrays, "[1,\"3\",[\"2023-08-25\"]]"); + + Row[] actualArraysRows = df.select(Functions.lit(List.of( + 1, + "3", + List.of(Date.valueOf("2023-08-25")) + ))).collect(); + Object[] actualArrays = rowsToString.apply(actualArraysRows); + + assertEquals(expectedArrays, actualArrays); + + // One or more map keys are not of the String type. Should throw an exception. 
+ assertThrows( + scala.NotImplementedError.class, + () -> df.select(Functions.lit(Map.of("1", 1, 2, 2))) + ); + + // Map with different type of elements + String[] expectedMaps = new String[3]; + Arrays.fill(expectedMaps, "{\"key1\":{\"nestedKey\":42},\"key2\":\"2023-08-25\"}"); + + Row[] actualMapsRows = df.select(Functions.lit(Map.of( + "key1", Map.of("nestedKey", 42), + "key2", Date.valueOf("2023-08-25")) + )).collect(); + Object[] actualMaps = rowsToString.apply(actualMapsRows); + + assertEquals(expectedMaps, actualMaps); + } @Test public void sqlText() { From ef103dc2f23764694f7cc1be253cf44e431f6833 Mon Sep 17 00:00:00 2001 From: Jean-Francis Roy Date: Wed, 13 Dec 2023 15:24:14 -0500 Subject: [PATCH 2/4] Add Coveo Cloud maven repo. --- pom.xml | 172 ++++++------------ .../snowflake/snowpark_java/Functions.java | 13 +- .../com/snowflake/snowpark/functions.scala | 1 - .../snowpark/internal/JavaUtils.scala | 7 +- .../internal/analyzer/DataTypeMapper.scala | 119 ++++++------ .../snowpark/internal/analyzer/Literal.scala | 7 +- .../snowpark_test/JavaFunctionSuite.java | 60 +++--- 7 files changed, 155 insertions(+), 224 deletions(-) diff --git a/pom.xml b/pom.xml index 57f273ba..6f866b89 100644 --- a/pom.xml +++ b/pom.xml @@ -443,66 +443,10 @@ - - org.scoverage - scoverage-maven-plugin - ${scoverage.plugin.version} - - - - check - - prepare-package - - - - - org.apache.maven.plugins - maven-gpg-plugin - 1.6 - - - sign-deploy-artifacts - package - - sign - - - - - - net.nicoulaj.maven.plugins - checksum-maven-plugin - 1.10 - - - package - - artifacts - - - - - - SHA-256 - md5 - - - - - - - maven-deploy-plugin - - true - - - - @@ -516,25 +460,6 @@ - org.apache.maven.plugins - maven-assembly-plugin - 3.3.0 - - - generate-tar-zip - none - - - with-dependencies - none - - - fat-test - none - - - - maven-jar-plugin 3.1.0 @@ -551,52 +476,63 @@ - - org.apache.maven.plugins - maven-gpg-plugin - 1.6 - - - sign-deploy-artifacts - none - - - sign-and-deploy-file - deploy - - sign-and-deploy-file - - - target/${project.artifactId}-${project.version}.jar - ossrh - https://oss.sonatype.org/service/local/staging/deploy/maven2 - pom.xml - target/${project.artifactId}-${project.version}-javadoc.jar - ${env.GPG_KEY_ID} - ${env.GPG_KEY_PASSPHRASE} - - - - - - - - org.scoverage - scoverage-maven-plugin - ${scoverage.plugin.version} - - - - report-only - - - - - - + + + + false + + central + libs-release + https://maven.cloud.coveo.com/artifactory/libs-release + + + + false + + snapshots + libs-snapshot + https://maven.cloud.coveo.com/artifactory/libs-snapshot + + + + + + false + + central + plugins-release + https://maven.cloud.coveo.com/artifactory/plugins-release + + + + false + + snapshots + plugins-snapshot + https://maven.cloud.coveo.com/artifactory/plugins-snapshot + + + + + + CoveoCloud + CoveoCloud-releases + https://maven.cloud.coveo.com/artifactory/libs-release-local + + + CoveoCloud + CoveoCloud-snapshots + https://maven.cloud.coveo.com/artifactory/libs-snapshot-local + + + ${project.artifactId} + ${project.artifactId} + ${project.artifactId}/ + + diff --git a/src/main/java/com/snowflake/snowpark_java/Functions.java b/src/main/java/com/snowflake/snowpark_java/Functions.java index 06477bb0..b762eb80 100644 --- a/src/main/java/com/snowflake/snowpark_java/Functions.java +++ b/src/main/java/com/snowflake/snowpark_java/Functions.java @@ -2365,14 +2365,13 @@ public static Column regexp_count(Column strExpr, Column pattern) { */ public static Column regexp_replace(Column strExpr, Column pattern) { return 
new Column( - com.snowflake.snowpark.functions.regexp_replace( - strExpr.toScalaColumn(), pattern.toScalaColumn())); + com.snowflake.snowpark.functions.regexp_replace( + strExpr.toScalaColumn(), pattern.toScalaColumn())); } /** - * Returns the subject with the specified pattern (or all occurrences of the pattern) - * replaced by a replacement string. If no matches are found, returns the original - * subject. + * Returns the subject with the specified pattern (or all occurrences of the pattern) replaced by + * a replacement string. If no matches are found, returns the original subject. * * @param strExpr The input string * @param pattern The pattern @@ -2382,8 +2381,8 @@ public static Column regexp_replace(Column strExpr, Column pattern) { */ public static Column regexp_replace(Column strExpr, Column pattern, Column replacement) { return new Column( - com.snowflake.snowpark.functions.regexp_replace( - strExpr.toScalaColumn(), pattern.toScalaColumn(), replacement.toScalaColumn())); + com.snowflake.snowpark.functions.regexp_replace( + strExpr.toScalaColumn(), pattern.toScalaColumn(), replacement.toScalaColumn())); } /** diff --git a/src/main/scala/com/snowflake/snowpark/functions.scala b/src/main/scala/com/snowflake/snowpark/functions.scala index 49f38593..86ad1969 100644 --- a/src/main/scala/com/snowflake/snowpark/functions.scala +++ b/src/main/scala/com/snowflake/snowpark/functions.scala @@ -1817,7 +1817,6 @@ object functions { def regexp_replace(strExpr: Column, pattern: Column, replacement: Column): Column = builtin("regexp_replace")(strExpr, pattern, replacement) - /** * Removes all occurrences of a specified strExpr, * and optionally replaces them with replacement. diff --git a/src/main/scala/com/snowflake/snowpark/internal/JavaUtils.scala b/src/main/scala/com/snowflake/snowpark/internal/JavaUtils.scala index 6d817507..9f4b6784 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/JavaUtils.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/JavaUtils.scala @@ -417,9 +417,10 @@ object JavaUtils { def toScala(element: Any): Any = { import collection.JavaConverters._ element match { - case map: java.util.Map[_, _] => mapAsScalaMap(map).map { - case (k, v) => toScala(k) -> toScala(v) - }.toMap + case map: java.util.Map[_, _] => + mapAsScalaMap(map).map { + case (k, v) => toScala(k) -> toScala(v) + }.toMap case iterable: java.lang.Iterable[_] => iterableAsScalaIterable(iterable).map(toScala) case iterator: java.util.Iterator[_] => asScalaIterator(iterator).map(toScala) case _ => element diff --git a/src/main/scala/com/snowflake/snowpark/internal/analyzer/DataTypeMapper.scala b/src/main/scala/com/snowflake/snowpark/internal/analyzer/DataTypeMapper.scala index 1993fd61..c55b45ea 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/DataTypeMapper.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/analyzer/DataTypeMapper.scala @@ -16,7 +16,7 @@ object DataTypeMapper { private val MICROS_PER_MILLIS = 1000L private[analyzer] def stringToSql(str: String): String = - // Escapes all backslashes, single quotes and new line. + // Escapes all backslashes, single quotes and new line. 
"'" + str .replaceAll("\\\\", "\\\\\\\\") .replaceAll("'", "''") @@ -27,63 +27,61 @@ object DataTypeMapper { */ private[analyzer] def toSql(literal: TLiteral): String = { literal match { - case Literal(value, dataType) => (value, dataType) match { - case (_, None) => "NULL" - case (value, Some(dt)) => - (value, dt) match { - case (_, _: ArrayType | _: MapType | _: StructType | GeographyType) if value == null => - "NULL" - case (_, IntegerType) if value == null => "NULL :: int" - case (_, ShortType) if value == null => "NULL :: smallint" - case (_, ByteType) if value == null => "NULL :: tinyint" - case (_, LongType) if value == null => "NULL :: bigint" - case (_, FloatType) if value == null => "NULL :: float" - case (_, StringType) if value == null => "NULL :: string" - case (_, DoubleType) if value == null => "NULL :: double" - case (_, BooleanType) if value == null => "NULL :: boolean" - case (_, BinaryType) if value == null => "NULL :: binary" - case _ if value == null => "NULL" - case (v: String, StringType) => stringToSql(v) - case (v: Byte, ByteType) => v + s" :: tinyint" - case (v: Short, ShortType) => v + s" :: smallint" - case (v: Any, IntegerType) => v + s" :: int" - case (v: Long, LongType) => v + s" :: bigint" - case (v: Boolean, BooleanType) => s"$v :: boolean" - // Float type doesn't have a suffix - case (v: Float, FloatType) => - val castedValue = v match { - case _ if v.isNaN => "'NaN'" - case Float.PositiveInfinity => "'Infinity'" - case Float.NegativeInfinity => "'-Infinity'" - case _ => s"'$v'" - } - s"$castedValue :: FLOAT" - case (v: Double, DoubleType) => - v match { - case _ if v.isNaN => "'NaN'" - case Double.PositiveInfinity => "'Infinity'" - case Double.NegativeInfinity => "'-Infinity'" - case _ => v + "::DOUBLE" - } - case (v: BigDecimal, t: DecimalType) => v + s" :: ${number(t.precision, t.scale)}" - case (v: JBigDecimal, t: DecimalType) => v + s" :: ${number(t.precision, t.scale)}" - case (v: Int, DateType) => - s"DATE '${ - SnowflakeDateTimeFormat + case Literal(value, dataType) => + (value, dataType) match { + case (_, None) => "NULL" + case (value, Some(dt)) => + (value, dt) match { + case (_, _: ArrayType | _: MapType | _: StructType | GeographyType) + if value == null => + "NULL" + case (_, IntegerType) if value == null => "NULL :: int" + case (_, ShortType) if value == null => "NULL :: smallint" + case (_, ByteType) if value == null => "NULL :: tinyint" + case (_, LongType) if value == null => "NULL :: bigint" + case (_, FloatType) if value == null => "NULL :: float" + case (_, StringType) if value == null => "NULL :: string" + case (_, DoubleType) if value == null => "NULL :: double" + case (_, BooleanType) if value == null => "NULL :: boolean" + case (_, BinaryType) if value == null => "NULL :: binary" + case _ if value == null => "NULL" + case (v: String, StringType) => stringToSql(v) + case (v: Byte, ByteType) => v + s" :: tinyint" + case (v: Short, ShortType) => v + s" :: smallint" + case (v: Any, IntegerType) => v + s" :: int" + case (v: Long, LongType) => v + s" :: bigint" + case (v: Boolean, BooleanType) => s"$v :: boolean" + // Float type doesn't have a suffix + case (v: Float, FloatType) => + val castedValue = v match { + case _ if v.isNaN => "'NaN'" + case Float.PositiveInfinity => "'Infinity'" + case Float.NegativeInfinity => "'-Infinity'" + case _ => s"'$v'" + } + s"$castedValue :: FLOAT" + case (v: Double, DoubleType) => + v match { + case _ if v.isNaN => "'NaN'" + case Double.PositiveInfinity => "'Infinity'" + case Double.NegativeInfinity => 
"'-Infinity'" + case _ => v + "::DOUBLE" + } + case (v: BigDecimal, t: DecimalType) => v + s" :: ${number(t.precision, t.scale)}" + case (v: JBigDecimal, t: DecimalType) => v + s" :: ${number(t.precision, t.scale)}" + case (v: Int, DateType) => + s"DATE '${SnowflakeDateTimeFormat .fromSqlFormat(Utils.DateInputFormat) - .format(new Date(v * MILLIS_PER_DAY), TimeZone.getTimeZone("GMT")) - }'" - case (v: Long, TimestampType) => - s"TIMESTAMP '${ - SnowflakeDateTimeFormat + .format(new Date(v * MILLIS_PER_DAY), TimeZone.getTimeZone("GMT"))}'" + case (v: Long, TimestampType) => + s"TIMESTAMP '${SnowflakeDateTimeFormat .fromSqlFormat(Utils.TimestampInputFormat) - .format(new Timestamp(v / MICROS_PER_MILLIS), TimeZone.getDefault, 3) - }'" - case _ => - throw new UnsupportedOperationException( - s"Unsupported datatype by ToSql: ${value.getClass.getName} => $dataType") - } - } + .format(new Timestamp(v / MICROS_PER_MILLIS), TimeZone.getDefault, 3)}'" + case _ => + throw new UnsupportedOperationException( + s"Unsupported datatype by ToSql: ${value.getClass.getName} => $dataType") + } + } case arrayLiteral: ArrayLiteral => if (arrayLiteral.dataTypeOption == Some(BinaryType)) { val bytes = arrayLiteral.value.asInstanceOf[Seq[Byte]].toArray @@ -92,9 +90,12 @@ object DataTypeMapper { "ARRAY_CONSTRUCT" + arrayLiteral.elementsLiterals.map(toSql).mkString("(", ", ", ")") } case mapLiteral: MapLiteral => - "OBJECT_CONSTRUCT" + mapLiteral.entriesLiterals.flatMap { case (keyLiteral, valueLiteral) => - Seq(toSql(keyLiteral), toSql(valueLiteral)) - }.mkString("(", ", ", ")") + "OBJECT_CONSTRUCT" + mapLiteral.entriesLiterals + .flatMap { + case (keyLiteral, valueLiteral) => + Seq(toSql(keyLiteral), toSql(valueLiteral)) + } + .mkString("(", ", ", ")") } } diff --git a/src/main/scala/com/snowflake/snowpark/internal/analyzer/Literal.scala b/src/main/scala/com/snowflake/snowpark/internal/analyzer/Literal.scala index 86804508..5ecea34b 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/Literal.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/analyzer/Literal.scala @@ -55,12 +55,13 @@ private[snowpark] trait TLiteral extends Expression { this } -private[snowpark] case class Literal (value: Any, dataTypeOption: Option[DataType]) extends TLiteral +private[snowpark] case class Literal(value: Any, dataTypeOption: Option[DataType]) + extends TLiteral private[snowpark] case class ArrayLiteral(value: Seq[Any]) extends TLiteral { val elementsLiterals: Seq[TLiteral] = value.map(Literal(_)) val dataTypeOption = inferArrayType - + private[analyzer] def inferArrayType(): Option[DataType] = { elementsLiterals.flatMap(_.dataTypeOption).distinct match { case Seq() => None @@ -74,7 +75,7 @@ private[snowpark] case class ArrayLiteral(value: Seq[Any]) extends TLiteral { private[snowpark] case class MapLiteral(value: Map[Any, Any]) extends TLiteral { val entriesLiterals = value.map { case (k, v) => Literal(k) -> Literal(v) } val dataTypeOption = inferMapType - + private[analyzer] def inferMapType(): Option[MapType] = { entriesLiterals.keys.flatMap(_.dataTypeOption).toSeq.distinct match { case Seq() => None diff --git a/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java index 6af340dc..ce76dc79 100644 --- a/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java +++ b/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java @@ -1,7 +1,8 @@ package com.snowflake.snowpark_test; -import 
com.snowflake.snowpark.internal.JavaUtils; -import com.snowflake.snowpark.internal.analyzer.Literal; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; + import com.snowflake.snowpark_java.*; import java.sql.Date; import java.sql.Time; @@ -11,13 +12,8 @@ import java.util.List; import java.util.Map; import java.util.function.Function; - -import jdk.jshell.spi.ExecutionControl; import org.junit.Test; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertThrows; - public class JavaFunctionSuite extends TestBase { @Test @@ -29,7 +25,7 @@ public void toScalar() { checkAnswer(df1.select(Functions.col("c1"), Functions.col(df2)), expected, false); checkAnswer(df1.select(Functions.col("c1"), Functions.toScalar(df2)), expected, false); } - + @Test public void lit() { DataFrame df = getSession().sql("select * from values (1),(2),(3) as T(a)"); @@ -46,47 +42,45 @@ public void lit() { // Array with only bytes should be considered Binary Row[] expectedBinary = new Row[3]; - Arrays.fill(expectedBinary, Row.create(new byte[]{(byte) 1, (byte) 2, (byte) 3})); + Arrays.fill(expectedBinary, Row.create(new byte[] {(byte) 1, (byte) 2, (byte) 3})); DataFrame actualBinary = df.select(Functions.lit(List.of((byte) 1, (byte) 2, (byte) 3))); checkAnswer(actualBinary, expectedBinary); - + // Array and Map results type are not supported, they are instead always converted to String. // Hence, we need to test by comparing results Strings. - Function rowsToString = (Row[] rows) -> Arrays.stream(rows) - .map((Row row) -> row.getString(0).replaceAll("\n| ", "")) - .toArray(); - + Function rowsToString = + (Row[] rows) -> + Arrays.stream(rows).map((Row row) -> row.getString(0).replaceAll("\n| ", "")).toArray(); + // Array with different types of elements String[] expectedArrays = new String[3]; Arrays.fill(expectedArrays, "[1,\"3\",[\"2023-08-25\"]]"); - - Row[] actualArraysRows = df.select(Functions.lit(List.of( - 1, - "3", - List.of(Date.valueOf("2023-08-25")) - ))).collect(); + + Row[] actualArraysRows = + df.select(Functions.lit(List.of(1, "3", List.of(Date.valueOf("2023-08-25"))))).collect(); Object[] actualArrays = rowsToString.apply(actualArraysRows); - + assertEquals(expectedArrays, actualArrays); - + // One or more map keys are not of the String type. Should throw an exception. 
     assertThrows(
-        scala.NotImplementedError.class,
-        () -> df.select(Functions.lit(Map.of("1", 1, 2, 2)))
-    );
-
+        scala.NotImplementedError.class, () -> df.select(Functions.lit(Map.of("1", 1, 2, 2))));
+
     // Map with different type of elements
     String[] expectedMaps = new String[3];
     Arrays.fill(expectedMaps, "{\"key1\":{\"nestedKey\":42},\"key2\":\"2023-08-25\"}");
-
-    Row[] actualMapsRows = df.select(Functions.lit(Map.of(
-        "key1", Map.of("nestedKey", 42),
-        "key2", Date.valueOf("2023-08-25"))
-    )).collect();
+
+    Row[] actualMapsRows =
+        df.select(
+                Functions.lit(
+                    Map.of(
+                        "key1", Map.of("nestedKey", 42),
+                        "key2", Date.valueOf("2023-08-25"))))
+            .collect();
     Object[] actualMaps = rowsToString.apply(actualMapsRows);
-
+
     assertEquals(expectedMaps, actualMaps);
   }
 
   @Test
@@ -1529,7 +1523,7 @@ public void regexp_replace() {
     Column replacement = Functions.lit("ch");
     Row[] expected1 = {Row.create("cht"), Row.create("chg"), Row.create("chuse")};
     checkAnswer(
-        df.select(Functions.regexp_replace(df.col("a"), pattern, replacement)), expected1, false);
+        df.select(Functions.regexp_replace(df.col("a"), pattern, replacement)), expected1, false);
   }
 
   @Test

From def43d6810925718de7d33ad12ec18d55adceb08 Mon Sep 17 00:00:00 2001
From: Jean-Francis Roy
Date: Thu, 14 Dec 2023 18:19:05 -0500
Subject: [PATCH 3/4] Apply another fix from jbergeron: create tables with
 CREATE OR REPLACE by default, and bump the fork version to 1.9.0-coveo-2.

---
 pom.xml                                                    | 2 +-
 .../com/snowflake/snowpark/internal/analyzer/package.scala | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/pom.xml b/pom.xml
index 6f866b89..0f01902c 100644
--- a/pom.xml
+++ b/pom.xml
@@ -4,7 +4,7 @@
   <modelVersion>4.0.0</modelVersion>
   <groupId>com.snowflake</groupId>
   <artifactId>snowpark</artifactId>
-  <version>1.9.0-coveo-1</version>
+  <version>1.9.0-coveo-2</version>
   <name>${project.artifactId}</name>
   <description>Snowflake's DataFrame API</description>
   <url>https://www.snowflake.com/</url>
diff --git a/src/main/scala/com/snowflake/snowpark/internal/analyzer/package.scala b/src/main/scala/com/snowflake/snowpark/internal/analyzer/package.scala
index ca7edc41..6c0cc960 100644
--- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/package.scala
+++ b/src/main/scala/com/snowflake/snowpark/internal/analyzer/package.scala
@@ -586,7 +586,7 @@ package object analyzer {
   private[analyzer] def createTableStatement(
       tableName: String,
       schema: String,
-      replace: Boolean = false,
+      replace: Boolean = true,
       error: Boolean = true,
       tempType: TempType = TempType.Permanent): String =
     _Create + (if (replace) _Or + _Replace else _EmptyString) + tempType + _Table + tableName +

From 863838cc1cacb2a42821773cd1d6756e586f3601 Mon Sep 17 00:00:00 2001
From: jbergeron
Date: Thu, 14 Mar 2024 15:43:31 -0400
Subject: [PATCH 4/4] Add regexp_extract_all function.

---
 .../scala/com/snowflake/snowpark/functions.scala   | 23 +++++++++++++++++++
 1 file changed, 23 insertions(+)

diff --git a/src/main/scala/com/snowflake/snowpark/functions.scala b/src/main/scala/com/snowflake/snowpark/functions.scala
index 86ad1969..2aaf9051 100644
--- a/src/main/scala/com/snowflake/snowpark/functions.scala
+++ b/src/main/scala/com/snowflake/snowpark/functions.scala
@@ -1817,6 +1817,29 @@ object functions {
   def regexp_replace(strExpr: Column, pattern: Column, replacement: Column): Column =
     builtin("regexp_replace")(strExpr, pattern, replacement)
 
+
+  /**
+   * Returns an array containing all substrings of strExpr that match the specified pattern.
+   * This compiles to Snowflake's REGEXP_EXTRACT_ALL (an alias of REGEXP_SUBSTR_ALL).
+   *
+   * @param strExpr The input string
+   * @param pattern The pattern
+   * @param position Position (in characters) where the search starts, 1 by default
+   * @param occurrence First occurrence of the pattern to match, 1 by default
+   * @param regex_parameters Matching parameters, "c" (case-sensitive) by default
+   * @param group_num_opt Optional capture group to extract instead of the whole match
+   * @return The result column
+   */
+  def regexp_extract_all(
+      strExpr: Column,
+      pattern: Column,
+      position: Column = lit(1),
+      occurrence: Column = lit(1),
+      regex_parameters: Column = lit("c"),
+      group_num_opt: Option[Column] = None): Column =
+    builtin("regexp_extract_all")(
+      (Seq(strExpr, pattern, position, occurrence, regex_parameters) ++ group_num_opt): _*)
+
   /**
    * Removes all occurrences of a specified strExpr,
    * and optionally replaces them with replacement.
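
For reference, the collection-literal support introduced in PATCH 1/4 can be exercised from Scala as sketched below. This is a minimal sketch, assuming a configured Snowpark Session; the profile path and sample values are illustrative, not part of the patches.

import com.snowflake.snowpark.Session
import com.snowflake.snowpark.functions.lit

object CollectionLiteralSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical connection profile; not part of the patches.
    val session = Session.builder.configFile("profile.properties").create
    val df = session.sql("select * from values (1),(2),(3) as T(a)")

    // A Seq with mixed element types infers ArrayType(VariantType) and renders
    // one literal per element: ARRAY_CONSTRUCT(1 :: int, '3').
    df.select(lit(Seq(1, "3"))).show()

    // A Seq containing only bytes infers BinaryType and renders as a single
    // hex literal: '010203' :: binary.
    df.select(lit(Seq(1.toByte, 2.toByte, 3.toByte))).show()

    // A Map must have String keys and renders as
    // OBJECT_CONSTRUCT('key1', 42 :: int, 'key2', 'v'); non-String keys throw.
    df.select(lit(Map("key1" -> 42, "key2" -> "v"))).show()
  }
}

The same conversions back the Java Functions.lit overload: JavaUtils.toScala recursively converts java.util.Map and java.lang.Iterable values to Scala collections before Literal inference runs.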
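
Similarly, a usage sketch for the regexp_extract_all wrapper from PATCH 4/4; the session, table values, and patterns are again illustrative. With only the first two arguments it effectively compiles to REGEXP_EXTRACT_ALL(subject, pattern, 1, 1, 'c'); extracting a capture group is done via group_num_opt, typically together with the 'e' regex parameter.

import com.snowflake.snowpark.Session
import com.snowflake.snowpark.functions.{col, lit, regexp_extract_all}

object RegexpExtractAllSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical connection profile; not part of the patches.
    val session = Session.builder.configFile("profile.properties").create
    val emails = session.sql("select * from values ('a@x b@y') as T(a)")

    // All whole-pattern matches, returned as an ARRAY column.
    emails.select(regexp_extract_all(col("a"), lit("\\w+@\\w+"))).show()

    // Second capture group of each match: pass the 'e' parameter and a group number.
    emails
      .select(
        regexp_extract_all(
          col("a"),
          lit("(\\w+)@(\\w+)"),
          regex_parameters = lit("ce"),
          group_num_opt = Some(lit(2))))
      .show()
  }
}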