Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Array and Map literals for the java scala codebase #50

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/main/java/com/snowflake/snowpark_java/Functions.java
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ public static Column toScalar(DataFrame df) {
* @return The result column
*/
public static Column lit(Object literal) {
return new Column(com.snowflake.snowpark.functions.lit(literal));
return new Column(com.snowflake.snowpark.functions.lit(JavaUtils.toScala(literal)));
}

/**
Expand Down
11 changes: 11 additions & 0 deletions src/main/scala/com/snowflake/snowpark/internal/JavaUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -414,4 +414,15 @@ object JavaUtils {
}
}

def toScala(element: Any): Any = {
import collection.JavaConverters._
element match {
case map: java.util.Map[_, _] => mapAsScalaMap(map).map {
case (k, v) => toScala(k) -> toScala(v)
}.toMap
case iterable: java.lang.Iterable[_] => iterableAsScalaIterable(iterable).map(toScala)
case iterator: java.util.Iterator[_] => asScalaIterator(iterator).map(toScala)
case _ => element
}
}
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Utility function that recursively convert a java map/array to a scala map/array

}
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
package com.snowflake.snowpark.internal.analyzer

import com.snowflake.snowpark.internal.Utils
import com.snowflake.snowpark.types._
import net.snowflake.client.jdbc.internal.snowflake.common.core.SnowflakeDateTimeFormat

import java.math.{BigDecimal => JBigDecimal}
import java.sql.{Date, Timestamp}
import java.util.TimeZone
import java.math.{BigDecimal => JBigDecimal}

import com.snowflake.snowpark.types._
import com.snowflake.snowpark.types.convertToSFType
import javax.xml.bind.DatatypeConverter
import net.snowflake.client.jdbc.internal.snowflake.common.core.SnowflakeDateTimeFormat

object DataTypeMapper {
// milliseconds per day
private val MILLIS_PER_DAY = 24 * 3600 * 1000L
// microseconds per millisecond
private val MICROS_PER_MILLIS = 1000L

private[analyzer] def stringToSql(str: String): String =
// Escapes all backslashes, single quotes and new line.
// Escapes all backslashes, single quotes and new line.
"'" + str
.replaceAll("\\\\", "\\\\\\\\")
.replaceAll("'", "''")
Expand All @@ -25,63 +25,77 @@ object DataTypeMapper {
/*
* Convert a value with DataType to a snowflake compatible sql
*/
private[analyzer] def toSql(value: Any, dataType: Option[DataType]): String = {
dataType match {
case None => "NULL"
case Some(dt) =>
(value, dt) match {
case (_, _: ArrayType | _: MapType | _: StructType | GeographyType) if value == null =>
"NULL"
case (_, IntegerType) if value == null => "NULL :: int"
case (_, ShortType) if value == null => "NULL :: smallint"
case (_, ByteType) if value == null => "NULL :: tinyint"
case (_, LongType) if value == null => "NULL :: bigint"
case (_, FloatType) if value == null => "NULL :: float"
case (_, StringType) if value == null => "NULL :: string"
case (_, DoubleType) if value == null => "NULL :: double"
case (_, BooleanType) if value == null => "NULL :: boolean"
case (_, BinaryType) if value == null => "NULL :: binary"
case _ if value == null => "NULL"
case (v: String, StringType) => stringToSql(v)
case (v: Byte, ByteType) => v + s" :: tinyint"
case (v: Short, ShortType) => v + s" :: smallint"
case (v: Any, IntegerType) => v + s" :: int"
case (v: Long, LongType) => v + s" :: bigint"
case (v: Boolean, BooleanType) => s"$v :: boolean"
// Float type doesn't have a suffix
case (v: Float, FloatType) =>
val castedValue = v match {
case _ if v.isNaN => "'NaN'"
case Float.PositiveInfinity => "'Infinity'"
case Float.NegativeInfinity => "'-Infinity'"
case _ => s"'$v'"
}
s"$castedValue :: FLOAT"
case (v: Double, DoubleType) =>
v match {
case _ if v.isNaN => "'NaN'"
case Double.PositiveInfinity => "'Infinity'"
case Double.NegativeInfinity => "'-Infinity'"
case _ => v + "::DOUBLE"
}
case (v: BigDecimal, t: DecimalType) => v + s" :: ${number(t.precision, t.scale)}"
case (v: JBigDecimal, t: DecimalType) => v + s" :: ${number(t.precision, t.scale)}"
case (v: Int, DateType) =>
s"DATE '${SnowflakeDateTimeFormat
.fromSqlFormat(Utils.DateInputFormat)
.format(new Date(v * MILLIS_PER_DAY), TimeZone.getTimeZone("GMT"))}'"
case (v: Long, TimestampType) =>
s"TIMESTAMP '${SnowflakeDateTimeFormat
.fromSqlFormat(Utils.TimestampInputFormat)
.format(new Timestamp(v / MICROS_PER_MILLIS), TimeZone.getDefault, 3)}'"
case (v: Array[Byte], BinaryType) =>
s"'${DatatypeConverter.printHexBinary(v)}' :: binary"
case _ =>
throw new UnsupportedOperationException(
s"Unsupported datatype by ToSql: ${value.getClass.getName} => $dataType")
private[analyzer] def toSql(literal: TLiteral): String = {
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now takes an instance of TLIteral, it is match against the children classes MapLiteral, ArrayLiteral and Literal. The Literal case contains the original code without change except for BinaryType that is moved inside of ArrayLiteral.

literal match {
case Literal(value, dataType) => (value, dataType) match {
case (_, None) => "NULL"
case (value, Some(dt)) =>
(value, dt) match {
case (_, _: ArrayType | _: MapType | _: StructType | GeographyType) if value == null =>
"NULL"
case (_, IntegerType) if value == null => "NULL :: int"
case (_, ShortType) if value == null => "NULL :: smallint"
case (_, ByteType) if value == null => "NULL :: tinyint"
case (_, LongType) if value == null => "NULL :: bigint"
case (_, FloatType) if value == null => "NULL :: float"
case (_, StringType) if value == null => "NULL :: string"
case (_, DoubleType) if value == null => "NULL :: double"
case (_, BooleanType) if value == null => "NULL :: boolean"
case (_, BinaryType) if value == null => "NULL :: binary"
case _ if value == null => "NULL"
case (v: String, StringType) => stringToSql(v)
case (v: Byte, ByteType) => v + s" :: tinyint"
case (v: Short, ShortType) => v + s" :: smallint"
case (v: Any, IntegerType) => v + s" :: int"
case (v: Long, LongType) => v + s" :: bigint"
case (v: Boolean, BooleanType) => s"$v :: boolean"
// Float type doesn't have a suffix
case (v: Float, FloatType) =>
val castedValue = v match {
case _ if v.isNaN => "'NaN'"
case Float.PositiveInfinity => "'Infinity'"
case Float.NegativeInfinity => "'-Infinity'"
case _ => s"'$v'"
}
s"$castedValue :: FLOAT"
case (v: Double, DoubleType) =>
v match {
case _ if v.isNaN => "'NaN'"
case Double.PositiveInfinity => "'Infinity'"
case Double.NegativeInfinity => "'-Infinity'"
case _ => v + "::DOUBLE"
}
case (v: BigDecimal, t: DecimalType) => v + s" :: ${number(t.precision, t.scale)}"
case (v: JBigDecimal, t: DecimalType) => v + s" :: ${number(t.precision, t.scale)}"
case (v: Int, DateType) =>
s"DATE '${
SnowflakeDateTimeFormat
.fromSqlFormat(Utils.DateInputFormat)
.format(new Date(v * MILLIS_PER_DAY), TimeZone.getTimeZone("GMT"))
}'"
case (v: Long, TimestampType) =>
s"TIMESTAMP '${
SnowflakeDateTimeFormat
.fromSqlFormat(Utils.TimestampInputFormat)
.format(new Timestamp(v / MICROS_PER_MILLIS), TimeZone.getDefault, 3)
}'"
case _ =>
throw new UnsupportedOperationException(
s"Unsupported datatype by ToSql: ${value.getClass.getName} => $dataType")
}
}
case arrayLiteral: ArrayLiteral =>
if (arrayLiteral.dataTypeOption == Some(BinaryType)) {
val bytes = arrayLiteral.value.asInstanceOf[Seq[Byte]].toArray
s"'${DatatypeConverter.printHexBinary(bytes)}' :: binary"
} else {
"ARRAY_CONSTRUCT" + arrayLiteral.elementsLiterals.map(toSql).mkString("(", ", ", ")")
}
case mapLiteral: MapLiteral =>
"OBJECT_CONSTRUCT" + mapLiteral.entriesLiterals.flatMap { case (keyLiteral, valueLiteral) =>
Seq(toSql(keyLiteral), toSql(valueLiteral))
}.mkString("(", ", ", ")")
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I use the ARRAY_CONSTRUCT AND OBJECT_CONSTRUCT here and recursively call toSql to populate the arguments of those methods.

}

}

private[analyzer] def schemaExpression(dataType: DataType, isNullable: Boolean): String =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@ package com.snowflake.snowpark.internal.analyzer

import com.snowflake.snowpark.internal.ErrorMessage
import com.snowflake.snowpark.types._

import java.math.{BigDecimal => JavaBigDecimal}
import java.sql.{Date, Timestamp}
import java.time.{Instant, LocalDate}

import scala.math.BigDecimal

private[snowpark] object Literal {
// Snowflake max precision for decimal is 38
private lazy val bigDecimalRoundContext = new java.math.MathContext(DecimalType.MAX_PRECISION)
Expand All @@ -16,7 +15,7 @@ private[snowpark] object Literal {
decimal.round(bigDecimalRoundContext)
}

def apply(v: Any): Literal = v match {
def apply(v: Any): TLiteral = v match {
case i: Int => Literal(i, Option(IntegerType))
case l: Long => Literal(l, Option(LongType))
case d: Double => Literal(d, Option(DoubleType))
Expand All @@ -36,7 +35,8 @@ private[snowpark] object Literal {
case t: Timestamp => Literal(DateTimeUtils.javaTimestampToMicros(t), Option(TimestampType))
case ld: LocalDate => Literal(DateTimeUtils.localDateToDays(ld), Option(DateType))
case d: Date => Literal(DateTimeUtils.javaDateToDays(d), Option(DateType))
case a: Array[Byte] => Literal(a, Option(BinaryType))
case s: Seq[Any] => ArrayLiteral(s)
case m: Map[Any, Any] => MapLiteral(m)
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We build the Array and Map Literal here, the data types are infered recursively in the classes themselves.

case null => Literal(null, None)
case v: Literal => v
case _ =>
Expand All @@ -45,10 +45,48 @@ private[snowpark] object Literal {

}

private[snowpark] case class Literal private (value: Any, dataTypeOption: Option[DataType])
extends Expression {
private[snowpark] trait TLiteral extends Expression {
def value: Any
def dataTypeOption: Option[DataType]

override def children: Seq[Expression] = Seq.empty

override protected def createAnalyzedExpression(analyzedChildren: Seq[Expression]): Expression =
this
}

private[snowpark] case class Literal (value: Any, dataTypeOption: Option[DataType]) extends TLiteral

private[snowpark] case class ArrayLiteral(value: Seq[Any]) extends TLiteral {
val elementsLiterals: Seq[TLiteral] = value.map(Literal(_))
val dataTypeOption = inferArrayType

private[analyzer] def inferArrayType(): Option[DataType] = {
elementsLiterals.flatMap(_.dataTypeOption).distinct match {
case Seq() => None
case Seq(ByteType) => Some(BinaryType)
case Seq(dt) => Some(ArrayType(dt))
case Seq(_, _*) => Some(ArrayType(VariantType))
}
}
}

private[snowpark] case class MapLiteral(value: Map[Any, Any]) extends TLiteral {
val entriesLiterals = value.map { case (k, v) => Literal(k) -> Literal(v) }
val dataTypeOption = inferMapType

private[analyzer] def inferMapType(): Option[MapType] = {
entriesLiterals.keys.flatMap(_.dataTypeOption).toSeq.distinct match {
case Seq() => None
case Seq(StringType) =>
val valuesTypes = entriesLiterals.values.flatMap(_.dataTypeOption).toSeq.distinct
valuesTypes match {
case Seq() => None
case Seq(dt) => Some(MapType(StringType, dt))
case Seq(_, _*) => Some(MapType(StringType, VariantType))
}
case _ =>
throw ErrorMessage.PLAN_CANNOT_CREATE_LITERAL(value.getClass.getCanonicalName, s"$value")
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All keys of a Map must be of String type otherwise an exception is thrown. Maybe the exception could be more precise here ?

}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,8 @@ private object SqlGenerator extends Logging {
case UnspecifiedFrame => ""
case SpecialFrameBoundaryExtractor(str) => str

case Literal(value, dataType) =>
DataTypeMapper.toSql(value, dataType)
case l: TLiteral =>
DataTypeMapper.toSql(l)
case attr: Attribute => quoteName(attr.name)
// unresolved expression
case UnresolvedAttribute(name) => name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package com.snowflake.snowpark.internal
import com.snowflake.snowpark.FileOperationCommand._
import com.snowflake.snowpark.Row
import com.snowflake.snowpark.internal.Utils.{TempObjectType, randomNameForTempObject}
import com.snowflake.snowpark.types.{DataType, convertToSFType}
import com.snowflake.snowpark.types.{ArrayType, DataType, MapType, convertToSFType}

package object analyzer {
// constant string
Expand Down Expand Up @@ -446,7 +446,9 @@ package object analyzer {
val types = output.map(_.dataType)
val rows = data.map { row =>
val cells = row.toSeq.zip(types).map {
case (v, dType) => DataTypeMapper.toSql(v, Option(dType))
case (v: Seq[Any], _: ArrayType) => DataTypeMapper.toSql(ArrayLiteral(v))
case (v: Map[Any, Any], _: MapType) => DataTypeMapper.toSql(MapLiteral(v))
case (v, dType) => DataTypeMapper.toSql(Literal(v, Option(dType)))
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part of the code is a bit puzzling me, I think Array and Map are not supported by the Values function, so maybe I should throw something here ?

In any case, the design of the current code is a bit flawed here. Take for example a list of ints that represent DateType, when creating the ArrayLiteral the type is not taken into consideration and is instead infered (so Int). When converting to SQL we will not call the DATE function.

}
cells.mkString(_LeftParenthesis, _Comma, _RightParenthesis)
}
Expand Down
72 changes: 72 additions & 0 deletions src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java
Original file line number Diff line number Diff line change
@@ -1,11 +1,23 @@
package com.snowflake.snowpark_test;

import com.snowflake.snowpark.internal.JavaUtils;
import com.snowflake.snowpark.internal.analyzer.Literal;
import com.snowflake.snowpark_java.*;
import java.sql.Date;
import java.sql.Time;
import java.sql.Timestamp;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

import jdk.jshell.spi.ExecutionControl;
import org.junit.Test;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThrows;

public class JavaFunctionSuite extends TestBase {

@Test
Expand All @@ -17,6 +29,66 @@ public void toScalar() {
checkAnswer(df1.select(Functions.col("c1"), Functions.col(df2)), expected, false);
checkAnswer(df1.select(Functions.col("c1"), Functions.toScalar(df2)), expected, false);
}

@Test
public void lit() {
DataFrame df = getSession().sql("select * from values (1),(2),(3) as T(a)");

// Empty array is supported
Row[] expectedEmptyArray = new Row[3];
Arrays.fill(expectedEmptyArray, Row.create("[]"));
checkAnswer(df.select(Functions.lit(Collections.EMPTY_LIST)), expectedEmptyArray, false);

// Empty map is supported
Row[] expectedEmptyMap = new Row[3];
Arrays.fill(expectedEmptyMap, Row.create("{}"));
checkAnswer(df.select(Functions.lit(Collections.EMPTY_MAP)), expectedEmptyMap, false);

// Array with only bytes should be considered Binary
Row[] expectedBinary = new Row[3];
Arrays.fill(expectedBinary, Row.create(new byte[]{(byte) 1, (byte) 2, (byte) 3}));

DataFrame actualBinary = df.select(Functions.lit(List.of((byte) 1, (byte) 2, (byte) 3)));

checkAnswer(actualBinary, expectedBinary);

// Array and Map results type are not supported, they are instead always converted to String.
// Hence, we need to test by comparing results Strings.
Function<Row[], Object[]> rowsToString = (Row[] rows) -> Arrays.stream(rows)
.map((Row row) -> row.getString(0).replaceAll("\n| ", ""))
.toArray();

// Array with different types of elements
String[] expectedArrays = new String[3];
Arrays.fill(expectedArrays, "[1,\"3\",[\"2023-08-25\"]]");

Row[] actualArraysRows = df.select(Functions.lit(List.of(
1,
"3",
List.of(Date.valueOf("2023-08-25"))
))).collect();
Object[] actualArrays = rowsToString.apply(actualArraysRows);

assertEquals(expectedArrays, actualArrays);

// One or more map keys are not of the String type. Should throw an exception.
assertThrows(
scala.NotImplementedError.class,
() -> df.select(Functions.lit(Map.of("1", 1, 2, 2)))
);

// Map with different type of elements
String[] expectedMaps = new String[3];
Arrays.fill(expectedMaps, "{\"key1\":{\"nestedKey\":42},\"key2\":\"2023-08-25\"}");

Row[] actualMapsRows = df.select(Functions.lit(Map.of(
"key1", Map.of("nestedKey", 42),
"key2", Date.valueOf("2023-08-25"))
)).collect();
Object[] actualMaps = rowsToString.apply(actualMapsRows);

assertEquals(expectedMaps, actualMaps);
}

@Test
public void sqlText() {
Expand Down