Skip to content

Commit

Permalink
reformat
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-bli committed Jun 18, 2024
2 parents b390d45 + d994b92 commit 3eb4e71
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 33 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
repos:
- repo: [email protected]:snowflakedb/casec_precommit.git
rev: HEAD
rev: v1.35.4
hooks:
- id: secret-scanner
24 changes: 17 additions & 7 deletions src/main/scala/com/snowflake/snowpark/types/Variant.scala
Original file line number Diff line number Diff line change
Expand Up @@ -162,13 +162,23 @@ class Variant private[snowpark] (
* @since 0.2.0
*/
def this(str: String) =
this({
try {
MAPPER.readTree(str)
} catch {
case _: Exception => JsonNodeFactory.instance.textNode(str)
}
}, VariantTypes.String)
this(
{
try {
// `ObjectMapper` only reads the first token from
// the input string but not the whole string.
// For example, It can successfully
// convert "null dummy" to `null` value without reporting error.
if (str.toLowerCase().startsWith("null") && str != "null") {
JsonNodeFactory.instance.textNode(str)
} else {
MAPPER.readTree(str)
}
} catch {
case _: Exception => JsonNodeFactory.instance.textNode(str)
}
},
VariantTypes.String)

/**
* Creates a Variant from binary value
Expand Down
16 changes: 16 additions & 0 deletions src/test/scala/com/snowflake/snowpark_test/DataTypeSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package com.snowflake.snowpark_test
import com.snowflake.snowpark.{Row, SNTestBase, TestUtils}
import com.snowflake.snowpark.types._
import com.snowflake.snowpark.functions._
import com.snowflake.snowpark.internal.Utils

import java.sql.{Date, Time, Timestamp}
import java.util.TimeZone
Expand Down Expand Up @@ -588,4 +589,19 @@ class DataTypeSuite extends SNTestBase {
|""".stripMargin)
// scalastyle:on
}

test("Variant containing word null in the text") {
import session.implicits._
var variant = new Variant("null string starts with null")
var df = Seq(variant).toDF("a")
checkAnswer(df, Row("\"null string starts with null\""))

variant = new Variant("string with null in the middle")
df = Seq(variant).toDF("a")
checkAnswer(df, Row("\"string with null in the middle\""))

variant = new Variant("string with null in the end null")
df = Seq(variant).toDF("a")
checkAnswer(df, Row("\"string with null in the end null\""))
}
}
41 changes: 16 additions & 25 deletions src/test/scala/com/snowflake/snowpark_test/UDTFSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,7 @@ class UDTFSuite extends TestData {
Seq(
Row("w1 w2", "g1", "w2", 1),
Row("w1 w2", "g1", "w1", 1),
Row("w1 w1 w1", "g2", "w1", 3)),
false)
Row("w1 w1 w1", "g2", "w1", 3)))
} finally {
runQuery(s"drop function if exists $funcName(STRING)", session)
}
Expand Down Expand Up @@ -145,8 +144,7 @@ class UDTFSuite extends TestData {
|""".stripMargin)
checkAnswer(
df1,
Seq(Row("w3", 6), Row("w2", 4), Row("w1", 2), Row("w3", 6), Row("w2", 4), Row("w1", 2)),
false)
Seq(Row("w3", 6), Row("w2", 4), Row("w1", 2), Row("w3", 6), Row("w2", 4), Row("w1", 2)))

// Call the UDTF with funcName and named parameters, result should be the same
val df2 = session
Expand All @@ -162,8 +160,7 @@ class UDTFSuite extends TestData {
|""".stripMargin)
checkAnswer(
df2,
Seq(Row(6, "w3"), Row(4, "w2"), Row(2, "w1"), Row(6, "w3"), Row(4, "w2"), Row(2, "w1")),
false)
Seq(Row(6, "w3"), Row(4, "w2"), Row(2, "w1"), Row(6, "w3"), Row(4, "w2"), Row(2, "w1")))

// scalastyle:off
// Use UDTF with table join
Expand Down Expand Up @@ -199,8 +196,7 @@ class UDTFSuite extends TestData {
Row(null, null, "g2", 1),
Row(null, null, "w2", 1),
Row(null, null, "g1", 1),
Row(null, null, "w1", 4)),
false)
Row(null, null, "w1", 4)))

// Use UDTF with table function + over partition
val df4 = session.sql(
Expand All @@ -217,8 +213,7 @@ class UDTFSuite extends TestData {
Row("w1 w1 w1", "g2", "g2", 1),
Row("w1 w1 w1", "g2", "w1", 3),
Row(null, "g2", "g2", 1),
Row(null, "g2", "w1", 3)),
false)
Row(null, "g2", "w1", 3)))
} finally {
runQuery(s"drop function if exists $funcName(VARCHAR,VARCHAR)", session)
}
Expand All @@ -245,15 +240,15 @@ class UDTFSuite extends TestData {
"""root
| |--C1: Long (nullable = true)
|""".stripMargin)
checkAnswer(df1, Seq(Row(10), Row(11), Row(12), Row(13), Row(14)), false)
checkAnswer(df1, Seq(Row(10), Row(11), Row(12), Row(13), Row(14)))

val df2 = session.tableFunction(TableFunction(funcName), lit(20), lit(5))
assert(
getSchemaString(df2.schema) ==
"""root
| |--C1: Long (nullable = true)
|""".stripMargin)
checkAnswer(df2, Seq(Row(20), Row(21), Row(22), Row(23), Row(24)), false)
checkAnswer(df2, Seq(Row(20), Row(21), Row(22), Row(23), Row(24)))

val df3 = session
.tableFunction(tableFunction, Map("arg1" -> lit(30), "arg2" -> lit(5)))
Expand All @@ -262,7 +257,7 @@ class UDTFSuite extends TestData {
"""root
| |--C1: Long (nullable = true)
|""".stripMargin)
checkAnswer(df3, Seq(Row(30), Row(31), Row(32), Row(33), Row(34)), false)
checkAnswer(df3, Seq(Row(30), Row(31), Row(32), Row(33), Row(34)))
} finally {
runQuery(s"drop function if exists $funcName(NUMBER, NUMBER)", session)
}
Expand Down Expand Up @@ -377,15 +372,15 @@ class UDTFSuite extends TestData {
"""root
| |--C1: Long (nullable = true)
|""".stripMargin)
checkAnswer(df1, Seq(Row(10), Row(11), Row(20), Row(21), Row(22), Row(23)), false)
checkAnswer(df1, Seq(Row(10), Row(11), Row(20), Row(21), Row(22), Row(23)))

val df2 = session.tableFunction(TableFunction(funcName), sourceDF("b"), sourceDF("c"))
assert(
getSchemaString(df2.schema) ==
"""root
| |--C1: Long (nullable = true)
|""".stripMargin)
checkAnswer(df2, Seq(Row(100), Row(101), Row(200), Row(201), Row(202), Row(203)), false)
checkAnswer(df2, Seq(Row(100), Row(101), Row(200), Row(201), Row(202), Row(203)))

// Check table function with df column arguments as Map
val sourceDF2 = Seq(30).toDF("a")
Expand All @@ -396,7 +391,7 @@ class UDTFSuite extends TestData {
"""root
| |--C1: Long (nullable = true)
|""".stripMargin)
checkAnswer(df3, Seq(Row(30), Row(31), Row(32), Row(33), Row(34)), false)
checkAnswer(df3, Seq(Row(30), Row(31), Row(32), Row(33), Row(34)))

// Check table function with nested functions on df column
val df4 = session.tableFunction(tableFunction, abs(ceil(sourceDF("a"))), lit(2))
Expand All @@ -405,7 +400,7 @@ class UDTFSuite extends TestData {
"""root
| |--C1: Long (nullable = true)
|""".stripMargin)
checkAnswer(df4, Seq(Row(10), Row(11), Row(20), Row(21)), false)
checkAnswer(df4, Seq(Row(10), Row(11), Row(20), Row(21)))

// Check result df column filtering with duplicate column names
val sourceDF3 = Seq(30).toDF("C1")
Expand All @@ -419,7 +414,7 @@ class UDTFSuite extends TestData {
"""root
| |--C1: Long (nullable = true)
|""".stripMargin)
checkAnswer(df, Seq(Row(30), Row(31), Row(32), Row(33), Row(34)), false)
checkAnswer(df, Seq(Row(30), Row(31), Row(32), Row(33), Row(34)))
}
} finally {
runQuery(s"drop function if exists $funcName(NUMBER, NUMBER)", session)
Expand Down Expand Up @@ -1866,25 +1861,21 @@ class UDTFSuite extends TestData {
val tableFunction3 = session.udtf.registerTemporary(new ReturnManyColumns(3))
val df3 = session.tableFunction(tableFunction3, lit(10))
assert(df3.schema.length == 3)
checkAnswer(df3, Seq(Row(11, 12, 13), Row(1, 2, 3)), false)
checkAnswer(df3, Seq(Row(11, 12, 13), Row(1, 2, 3)))

// Test UDTF return 100 columns
val tableFunction100 = session.udtf.registerTemporary(new ReturnManyColumns(100))
val df100 = session.tableFunction(tableFunction100, lit(20))
assert(df100.schema.length == 100)
checkAnswer(
df100,
Seq(Row.fromArray((21 to 120).toArray), Row.fromArray((1 to 100).toArray)),
false)
checkAnswer(df100, Seq(Row.fromArray((21 to 120).toArray), Row.fromArray((1 to 100).toArray)))

// Test UDTF return 200 columns
val tableFunction200 = session.udtf.registerTemporary(new ReturnManyColumns(200))
val df200 = session.tableFunction(tableFunction200, lit(100))
assert(df200.schema.length == 200)
checkAnswer(
df200,
Seq(Row.fromArray((101 to 300).toArray), Row.fromArray((1 to 200).toArray)),
false)
Seq(Row.fromArray((101 to 300).toArray), Row.fromArray((1 to 200).toArray)))
}

test("test output type: basic types") {
Expand Down

0 comments on commit 3eb4e71

Please sign in to comment.