Commit fc756a5

Apply Jonathan Bergeron's PR snowflakedb#50 from GitHub's snowpark-java-scala in a Coveo fork. SEARCHREL-547

jeanfrancisroy committed Dec 13, 2023
1 parent eb73f8a commit fc756a5
Showing 8 changed files with 210 additions and 73 deletions.
pom.xml (2 changes: 1 addition & 1 deletion)
@@ -4,7 +4,7 @@
   <modelVersion>4.0.0</modelVersion>
   <groupId>com.snowflake</groupId>
   <artifactId>snowpark</artifactId>
-  <version>1.9.0</version>
+  <version>1.9.0-coveo-1</version>
   <name>${project.artifactId}</name>
   <description>Snowflake's DataFrame API</description>
   <url>https://www.snowflake.com/</url>
src/main/java/com/snowflake/snowpark_java/Functions.java (2 changes: 1 addition & 1 deletion)
@@ -79,7 +79,7 @@ public static Column toScalar(DataFrame df) {
    * @return The result column
    */
   public static Column lit(Object literal) {
-    return new Column(com.snowflake.snowpark.functions.lit(literal));
+    return new Column(com.snowflake.snowpark.functions.lit(JavaUtils.toScala(literal)));
   }
 
   /**
src/main/scala/com/snowflake/snowpark/internal/JavaUtils.scala (11 changes: 11 additions & 0 deletions)
@@ -414,4 +414,15 @@ object JavaUtils {
     }
   }
 
+  def toScala(element: Any): Any = {
+    import collection.JavaConverters._
+    element match {
+      case map: java.util.Map[_, _] => mapAsScalaMap(map).map {
+          case (k, v) => toScala(k) -> toScala(v)
+        }.toMap
+      case iterable: java.lang.Iterable[_] => iterableAsScalaIterable(iterable).map(toScala)
+      case iterator: java.util.Iterator[_] => asScalaIterator(iterator).map(toScala)
+      case _ => element
+    }
+  }
 }
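
The Functions.java change above relies on this new helper: Java collection literals are converted to their Scala counterparts before literal inference runs. A minimal sketch of the intended behavior (not part of the commit; the demo object and values are illustrative):

import java.util.{ArrayList => JArrayList, HashMap => JHashMap}

import com.snowflake.snowpark.internal.JavaUtils

object ToScalaDemo extends App {
  // Build a nested Java structure: a map containing a list.
  val tags = new JArrayList[String]()
  tags.add("a")
  tags.add("b")
  val payload = new JHashMap[String, Object]()
  payload.put("tags", tags)
  payload.put("count", Int.box(2))

  // The java.util.Map arm converts keys and values recursively, and the
  // java.lang.Iterable arm converts the nested list's elements, so this
  // should print a Scala Map along the lines of:
  //   Map(tags -> List(a, b), count -> 2)
  println(JavaUtils.toScala(payload))
}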
src/main/scala/com/snowflake/snowpark/internal/analyzer/DataTypeMapper.scala
@@ -1,22 +1,22 @@
 package com.snowflake.snowpark.internal.analyzer
 
 import com.snowflake.snowpark.internal.Utils
+import com.snowflake.snowpark.types._
+import net.snowflake.client.jdbc.internal.snowflake.common.core.SnowflakeDateTimeFormat
+
 import java.math.{BigDecimal => JBigDecimal}
 import java.sql.{Date, Timestamp}
 import java.util.TimeZone
-import java.math.{BigDecimal => JBigDecimal}
-
-import com.snowflake.snowpark.types._
-import com.snowflake.snowpark.types.convertToSFType
 import javax.xml.bind.DatatypeConverter
-import net.snowflake.client.jdbc.internal.snowflake.common.core.SnowflakeDateTimeFormat
 
 object DataTypeMapper {
   // milliseconds per day
   private val MILLIS_PER_DAY = 24 * 3600 * 1000L
   // microseconds per millisecond
   private val MICROS_PER_MILLIS = 1000L
 
   private[analyzer] def stringToSql(str: String): String =
     // Escapes all backslashes, single quotes and new line.
     "'" + str
       .replaceAll("\\\\", "\\\\\\\\")
       .replaceAll("'", "''")
@@ -25,63 +25,77 @@ object DataTypeMapper {
   /*
    * Convert a value with DataType to a snowflake compatible sql
    */
-  private[analyzer] def toSql(value: Any, dataType: Option[DataType]): String = {
-    dataType match {
-      case None => "NULL"
-      case Some(dt) =>
-        (value, dt) match {
-          case (_, _: ArrayType | _: MapType | _: StructType | GeographyType) if value == null =>
-            "NULL"
-          case (_, IntegerType) if value == null => "NULL :: int"
-          case (_, ShortType) if value == null => "NULL :: smallint"
-          case (_, ByteType) if value == null => "NULL :: tinyint"
-          case (_, LongType) if value == null => "NULL :: bigint"
-          case (_, FloatType) if value == null => "NULL :: float"
-          case (_, StringType) if value == null => "NULL :: string"
-          case (_, DoubleType) if value == null => "NULL :: double"
-          case (_, BooleanType) if value == null => "NULL :: boolean"
-          case (_, BinaryType) if value == null => "NULL :: binary"
-          case _ if value == null => "NULL"
-          case (v: String, StringType) => stringToSql(v)
-          case (v: Byte, ByteType) => v + s" :: tinyint"
-          case (v: Short, ShortType) => v + s" :: smallint"
-          case (v: Any, IntegerType) => v + s" :: int"
-          case (v: Long, LongType) => v + s" :: bigint"
-          case (v: Boolean, BooleanType) => s"$v :: boolean"
-          // Float type doesn't have a suffix
-          case (v: Float, FloatType) =>
-            val castedValue = v match {
-              case _ if v.isNaN => "'NaN'"
-              case Float.PositiveInfinity => "'Infinity'"
-              case Float.NegativeInfinity => "'-Infinity'"
-              case _ => s"'$v'"
-            }
-            s"$castedValue :: FLOAT"
-          case (v: Double, DoubleType) =>
-            v match {
-              case _ if v.isNaN => "'NaN'"
-              case Double.PositiveInfinity => "'Infinity'"
-              case Double.NegativeInfinity => "'-Infinity'"
-              case _ => v + "::DOUBLE"
-            }
-          case (v: BigDecimal, t: DecimalType) => v + s" :: ${number(t.precision, t.scale)}"
-          case (v: JBigDecimal, t: DecimalType) => v + s" :: ${number(t.precision, t.scale)}"
-          case (v: Int, DateType) =>
-            s"DATE '${SnowflakeDateTimeFormat
-              .fromSqlFormat(Utils.DateInputFormat)
-              .format(new Date(v * MILLIS_PER_DAY), TimeZone.getTimeZone("GMT"))}'"
-          case (v: Long, TimestampType) =>
-            s"TIMESTAMP '${SnowflakeDateTimeFormat
-              .fromSqlFormat(Utils.TimestampInputFormat)
-              .format(new Timestamp(v / MICROS_PER_MILLIS), TimeZone.getDefault, 3)}'"
-          case (v: Array[Byte], BinaryType) =>
-            s"'${DatatypeConverter.printHexBinary(v)}' :: binary"
-          case _ =>
-            throw new UnsupportedOperationException(
-              s"Unsupported datatype by ToSql: ${value.getClass.getName} => $dataType")
-        }
-    }
-  }
+  private[analyzer] def toSql(literal: TLiteral): String = {
+    literal match {
+      case Literal(value, dataType) => (value, dataType) match {
+        case (_, None) => "NULL"
+        case (value, Some(dt)) =>
+          (value, dt) match {
+            case (_, _: ArrayType | _: MapType | _: StructType | GeographyType) if value == null =>
+              "NULL"
+            case (_, IntegerType) if value == null => "NULL :: int"
+            case (_, ShortType) if value == null => "NULL :: smallint"
+            case (_, ByteType) if value == null => "NULL :: tinyint"
+            case (_, LongType) if value == null => "NULL :: bigint"
+            case (_, FloatType) if value == null => "NULL :: float"
+            case (_, StringType) if value == null => "NULL :: string"
+            case (_, DoubleType) if value == null => "NULL :: double"
+            case (_, BooleanType) if value == null => "NULL :: boolean"
+            case (_, BinaryType) if value == null => "NULL :: binary"
+            case _ if value == null => "NULL"
+            case (v: String, StringType) => stringToSql(v)
+            case (v: Byte, ByteType) => v + s" :: tinyint"
+            case (v: Short, ShortType) => v + s" :: smallint"
+            case (v: Any, IntegerType) => v + s" :: int"
+            case (v: Long, LongType) => v + s" :: bigint"
+            case (v: Boolean, BooleanType) => s"$v :: boolean"
+            // Float type doesn't have a suffix
+            case (v: Float, FloatType) =>
+              val castedValue = v match {
+                case _ if v.isNaN => "'NaN'"
+                case Float.PositiveInfinity => "'Infinity'"
+                case Float.NegativeInfinity => "'-Infinity'"
+                case _ => s"'$v'"
+              }
+              s"$castedValue :: FLOAT"
+            case (v: Double, DoubleType) =>
+              v match {
+                case _ if v.isNaN => "'NaN'"
+                case Double.PositiveInfinity => "'Infinity'"
+                case Double.NegativeInfinity => "'-Infinity'"
+                case _ => v + "::DOUBLE"
+              }
+            case (v: BigDecimal, t: DecimalType) => v + s" :: ${number(t.precision, t.scale)}"
+            case (v: JBigDecimal, t: DecimalType) => v + s" :: ${number(t.precision, t.scale)}"
+            case (v: Int, DateType) =>
+              s"DATE '${
+                SnowflakeDateTimeFormat
+                  .fromSqlFormat(Utils.DateInputFormat)
+                  .format(new Date(v * MILLIS_PER_DAY), TimeZone.getTimeZone("GMT"))
+              }'"
+            case (v: Long, TimestampType) =>
+              s"TIMESTAMP '${
+                SnowflakeDateTimeFormat
+                  .fromSqlFormat(Utils.TimestampInputFormat)
+                  .format(new Timestamp(v / MICROS_PER_MILLIS), TimeZone.getDefault, 3)
+              }'"
+            case _ =>
+              throw new UnsupportedOperationException(
+                s"Unsupported datatype by ToSql: ${value.getClass.getName} => $dataType")
+          }
+        }
+      case arrayLiteral: ArrayLiteral =>
+        if (arrayLiteral.dataTypeOption == Some(BinaryType)) {
+          val bytes = arrayLiteral.value.asInstanceOf[Seq[Byte]].toArray
+          s"'${DatatypeConverter.printHexBinary(bytes)}' :: binary"
+        } else {
+          "ARRAY_CONSTRUCT" + arrayLiteral.elementsLiterals.map(toSql).mkString("(", ", ", ")")
+        }
+      case mapLiteral: MapLiteral =>
+        "OBJECT_CONSTRUCT" + mapLiteral.entriesLiterals.flatMap { case (keyLiteral, valueLiteral) =>
+          Seq(toSql(keyLiteral), toSql(valueLiteral))
+        }.mkString("(", ", ", ")")
+    }
+
+  }
 
   private[analyzer] def schemaExpression(dataType: DataType, isNullable: Boolean): String =
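
As a hedged usage sketch of where the new toSql arms surface (the session setup and column names are assumptions, not part of the diff), Seq and Map literals are expected to render as ARRAY_CONSTRUCT and OBJECT_CONSTRUCT calls:

import com.snowflake.snowpark.Session
import com.snowflake.snowpark.functions.lit

// Illustrative session setup; any configured Session works.
val session = Session.builder.configFile("profile.properties").create

val df = session.range(1)
  .withColumn("arr", lit(Seq(1, 2, 3))) // expected SQL: ARRAY_CONSTRUCT(1 :: int, 2 :: int, 3 :: int)
  .withColumn("obj", lit(Map("a" -> 1))) // expected SQL: OBJECT_CONSTRUCT('a', 1 :: int)
df.show()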
src/main/scala/com/snowflake/snowpark/internal/analyzer/Literal.scala
@@ -2,12 +2,11 @@ package com.snowflake.snowpark.internal.analyzer
 
 import com.snowflake.snowpark.internal.ErrorMessage
 import com.snowflake.snowpark.types._
-
 import java.math.{BigDecimal => JavaBigDecimal}
 import java.sql.{Date, Timestamp}
 import java.time.{Instant, LocalDate}
-
 import scala.math.BigDecimal
+
 private[snowpark] object Literal {
   // Snowflake max precision for decimal is 38
   private lazy val bigDecimalRoundContext = new java.math.MathContext(DecimalType.MAX_PRECISION)
@@ -16,7 +15,7 @@ private[snowpark] object Literal {
     decimal.round(bigDecimalRoundContext)
   }
 
-  def apply(v: Any): Literal = v match {
+  def apply(v: Any): TLiteral = v match {
     case i: Int => Literal(i, Option(IntegerType))
     case l: Long => Literal(l, Option(LongType))
     case d: Double => Literal(d, Option(DoubleType))
@@ -36,7 +35,8 @@ private[snowpark] object Literal {
     case t: Timestamp => Literal(DateTimeUtils.javaTimestampToMicros(t), Option(TimestampType))
     case ld: LocalDate => Literal(DateTimeUtils.localDateToDays(ld), Option(DateType))
     case d: Date => Literal(DateTimeUtils.javaDateToDays(d), Option(DateType))
-    case a: Array[Byte] => Literal(a, Option(BinaryType))
+    case s: Seq[Any] => ArrayLiteral(s)
+    case m: Map[Any, Any] => MapLiteral(m)
     case null => Literal(null, None)
    case v: Literal => v
     case _ =>
@@ -45,10 +45,48 @@
 
 }
 
-private[snowpark] case class Literal private (value: Any, dataTypeOption: Option[DataType])
-    extends Expression {
+private[snowpark] trait TLiteral extends Expression {
+  def value: Any
+  def dataTypeOption: Option[DataType]
 
   override def children: Seq[Expression] = Seq.empty
 
   override protected def createAnalyzedExpression(analyzedChildren: Seq[Expression]): Expression =
     this
 }
+
+private[snowpark] case class Literal (value: Any, dataTypeOption: Option[DataType]) extends TLiteral
+
+private[snowpark] case class ArrayLiteral(value: Seq[Any]) extends TLiteral {
+  val elementsLiterals: Seq[TLiteral] = value.map(Literal(_))
+  val dataTypeOption = inferArrayType
+
+  private[analyzer] def inferArrayType(): Option[DataType] = {
+    elementsLiterals.flatMap(_.dataTypeOption).distinct match {
+      case Seq() => None
+      case Seq(ByteType) => Some(BinaryType)
+      case Seq(dt) => Some(ArrayType(dt))
+      case Seq(_, _*) => Some(ArrayType(VariantType))
+    }
+  }
+}
+
+private[snowpark] case class MapLiteral(value: Map[Any, Any]) extends TLiteral {
+  val entriesLiterals = value.map { case (k, v) => Literal(k) -> Literal(v) }
+  val dataTypeOption = inferMapType
+
+  private[analyzer] def inferMapType(): Option[MapType] = {
+    entriesLiterals.keys.flatMap(_.dataTypeOption).toSeq.distinct match {
+      case Seq() => None
+      case Seq(StringType) =>
+        val valuesTypes = entriesLiterals.values.flatMap(_.dataTypeOption).toSeq.distinct
+        valuesTypes match {
+          case Seq() => None
+          case Seq(dt) => Some(MapType(StringType, dt))
+          case Seq(_, _*) => Some(MapType(StringType, VariantType))
+        }
+      case _ =>
+        throw ErrorMessage.PLAN_CANNOT_CREATE_LITERAL(value.getClass.getCanonicalName, s"$value")
+    }
+  }
+}
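
To make the inference rules concrete, a few hedged examples of the match arms above (these classes are private[snowpark], so this is an illustration rather than user-facing API; the explicit Map[Any, Any] annotations are needed because the map key type is invariant):

// Expected results per the match arms above (illustrative, internal API).
ArrayLiteral(Seq(1, 2, 3)).dataTypeOption          // Some(ArrayType(IntegerType))
ArrayLiteral(Seq(1, "x")).dataTypeOption           // Some(ArrayType(VariantType)), mixed element types
ArrayLiteral(Seq[Byte](1, 2)).dataTypeOption       // Some(BinaryType), assuming Literal.apply maps Byte to ByteType
MapLiteral(Map[Any, Any]("a" -> 1)).dataTypeOption // Some(MapType(StringType, IntegerType))
MapLiteral(Map.empty[Any, Any]).dataTypeOption     // None
// MapLiteral(Map[Any, Any](1 -> "a")).dataTypeOption throws
// PLAN_CANNOT_CREATE_LITERAL, since map keys must all be strings.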
src/main/scala/com/snowflake/snowpark/internal/analyzer/SqlGenerator.scala
@@ -203,8 +203,8 @@ private object SqlGenerator extends Logging {
       case UnspecifiedFrame => ""
       case SpecialFrameBoundaryExtractor(str) => str
 
-      case Literal(value, dataType) =>
-        DataTypeMapper.toSql(value, dataType)
+      case l: TLiteral =>
+        DataTypeMapper.toSql(l)
       case attr: Attribute => quoteName(attr.name)
       // unresolved expression
       case UnresolvedAttribute(name) => name
src/main/scala/com/snowflake/snowpark/internal/analyzer/package.scala
@@ -3,7 +3,7 @@ package com.snowflake.snowpark.internal
 import com.snowflake.snowpark.FileOperationCommand._
 import com.snowflake.snowpark.Row
 import com.snowflake.snowpark.internal.Utils.{TempObjectType, randomNameForTempObject}
-import com.snowflake.snowpark.types.{DataType, convertToSFType}
+import com.snowflake.snowpark.types.{ArrayType, DataType, MapType, convertToSFType}
 
 package object analyzer {
   // constant string
@@ -446,7 +446,9 @@ package object analyzer {
     val types = output.map(_.dataType)
     val rows = data.map { row =>
       val cells = row.toSeq.zip(types).map {
-        case (v, dType) => DataTypeMapper.toSql(v, Option(dType))
+        case (v: Seq[Any], _: ArrayType) => DataTypeMapper.toSql(ArrayLiteral(v))
+        case (v: Map[Any, Any], _: MapType) => DataTypeMapper.toSql(MapLiteral(v))
+        case (v, dType) => DataTypeMapper.toSql(Literal(v, Option(dType)))
       }
       cells.mkString(_LeftParenthesis, _Comma, _RightParenthesis)
     }
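
The practical effect, sketched with assumed names (the session, schema, and data below are illustrative, not from the diff): rows carrying Scala collections in ArrayType or MapType columns can now be rendered into an inline VALUES list.

import com.snowflake.snowpark.Row
import com.snowflake.snowpark.types._

// Assumes a configured Session named `session` (illustrative).
val schema = StructType(Seq(
  StructField("id", IntegerType),
  StructField("tags", ArrayType(StringType))))
val df = session.createDataFrame(Seq(Row(1, Seq("a", "b"))), schema)
// Expected inline cell rendering: (1 :: int, ARRAY_CONSTRUCT('a', 'b'))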
