From 99792d49de4ad28bdb4e0831a3d4e9f37f9c26df Mon Sep 17 00:00:00 2001 From: Bing Li Date: Wed, 21 Aug 2024 15:27:55 -0700 Subject: [PATCH] update test format --- .../workflows/precommit-code-verification.yml | 24 + build.sbt | 28 +- scripts/format_checker.sh | 22 + .../com/snowflake/snowpark/functions.scala | 23 +- .../com/snowflake/snowpark/TestRunner.java | 41 - .../code_verification/ClassUtils.scala | 6 +- .../code_verification/JavaScalaAPISuite.scala | 207 +--- .../code_verification/PomSuite.scala | 14 +- .../scala/com/snowflake/perf/PerfBase.scala | 10 +- .../snowflake/snowpark/APIInternalSuite.scala | 158 +-- .../snowpark/DropTempObjectsSuite.scala | 23 +- .../snowpark/ErrorMessageSuite.scala | 572 +++-------- .../snowpark/ExpressionAndPlanNodeSuite.scala | 195 ++-- .../com/snowflake/snowpark/JavaAPISuite.scala | 124 +-- .../snowflake/snowpark/MethodChainSuite.scala | 30 +- .../snowpark/NewColumnReferenceSuite.scala | 49 +- .../snowpark/OpenTelemetryEnabled.scala | 12 +- .../snowflake/snowpark/ParameterSuite.scala | 3 +- .../com/snowflake/snowpark/ReplSuite.scala | 6 +- .../snowpark/ResultAttributesSuite.scala | 16 +- .../com/snowflake/snowpark/SFTestUtils.scala | 3 +- .../com/snowflake/snowpark/SNTestBase.scala | 26 +- .../snowpark/ServerConnectionSuite.scala | 19 +- .../snowflake/snowpark/SimplifierSuite.scala | 14 +- .../snowpark/SnowflakePlanSuite.scala | 24 +- .../SnowparkSFConnectionHandlerSuite.scala | 3 +- .../snowpark/StagedFileReaderSuite.scala | 9 +- .../com/snowflake/snowpark/TestData.scala | 97 +- .../com/snowflake/snowpark/TestUtils.scala | 59 +- .../snowpark/UDFClasspathSuite.scala | 6 +- .../snowflake/snowpark/UDFInternalSuite.scala | 19 +- .../snowpark/UDFRegistrationSuite.scala | 18 +- .../snowpark/UDTFInternalSuite.scala | 18 +- .../com/snowflake/snowpark/UtilsSuite.scala | 90 +- .../snowpark_test/AsyncJobSuite.scala | 127 +-- .../snowflake/snowpark_test/ColumnSuite.scala | 206 ++-- .../snowpark_test/ComplexDataFrameSuite.scala | 21 +- 
.../CopyableDataFrameSuite.scala | 317 ++---- .../DataFrameAggregateSuite.scala | 271 ++--- .../snowpark_test/DataFrameAliasSuite.scala | 34 +- .../snowpark_test/DataFrameJoinSuite.scala | 159 +-- .../DataFrameNonStoredProcSuite.scala | 24 +- .../snowpark_test/DataFrameReaderSuite.scala | 131 +-- .../DataFrameSetOperationsSuite.scala | 24 +- .../snowpark_test/DataFrameSuite.scala | 471 +++------ .../snowpark_test/DataFrameWriterSuite.scala | 119 +-- .../snowpark_test/DataTypeSuite.scala | 87 +- .../snowpark_test/FileOperationSuite.scala | 71 +- .../snowpark_test/FunctionSuite.scala | 924 ++++++------------ .../snowpark_test/IndependentClassSuite.scala | 32 +- .../snowpark_test/JavaUtilsSuite.scala | 36 +- .../snowpark_test/LargeDataFrameSuite.scala | 61 +- .../snowpark_test/LiteralSuite.scala | 42 +- .../snowpark_test/OpenTelemetrySuite.scala | 15 +- .../snowpark_test/PermanentUDFSuite.scala | 547 ++++------- .../snowpark_test/RequestTimeoutSuite.scala | 3 +- .../snowpark_test/ResultSchemaSuite.scala | 22 +- .../snowflake/snowpark_test/RowSuite.scala | 18 +- .../snowpark_test/ScalaGeographySuite.scala | 6 +- .../snowpark_test/ScalaVariantSuite.scala | 13 +- .../snowpark_test/SessionSuite.scala | 23 +- .../snowflake/snowpark_test/SqlSuite.scala | 6 +- .../snowpark_test/StoredProcedureSuite.scala | 477 +++------ .../snowpark_test/TableFunctionSuite.scala | 186 ++-- .../snowflake/snowpark_test/TableSuite.scala | 37 +- .../snowflake/snowpark_test/UDFSuite.scala | 776 ++++++--------- .../snowflake/snowpark_test/UDTFSuite.scala | 349 +++---- .../snowpark_test/UdxOpenTelemetrySuite.scala | 12 +- .../snowpark_test/UpdatableSuite.scala | 63 +- .../snowflake/snowpark_test/ViewSuite.scala | 9 +- .../snowpark_test/WindowFramesSuite.scala | 77 +- .../snowpark_test/WindowSpecSuite.scala | 147 +-- 72 files changed, 2581 insertions(+), 5330 deletions(-) create mode 100644 .github/workflows/precommit-code-verification.yml delete mode 100644 
src/test/java/com/snowflake/snowpark/TestRunner.java diff --git a/.github/workflows/precommit-code-verification.yml b/.github/workflows/precommit-code-verification.yml new file mode 100644 index 00000000..c474b087 --- /dev/null +++ b/.github/workflows/precommit-code-verification.yml @@ -0,0 +1,24 @@ +name: precommit test - code verification +on: + push: + branches: [ main ] + pull_request: + branches: '**' + +jobs: + build: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + - name: Setup JDK + uses: actions/setup-java@v3 + with: + distribution: temurin + java-version: 8 + - name: Decrypt profile.properties + run: .github/scripts/decrypt_profile.sh + env: + PROFILE_PASSWORD: ${{ secrets.PROFILE_PASSWORD }} + - name: Run test + run: sbt CodeVerificationTests:test \ No newline at end of file diff --git a/build.sbt b/build.sbt index 3df5bd88..e7838a94 100644 --- a/build.sbt +++ b/build.sbt @@ -5,6 +5,7 @@ val openTelemetryVersion = "1.41.0" val slf4jVersion = "2.0.4" lazy val root = (project in file(".")) + .configs(CodeVerificationTests) .settings( name := "snowpark", version := "1.15.0-SNAPSHOT", @@ -31,7 +32,7 @@ lazy val root = (project in file(".")) "commons-codec" % "commons-codec" % "1.17.0", "io.opentelemetry" % "opentelemetry-api" % openTelemetryVersion, "net.snowflake" % "snowflake-jdbc" % "3.17.0", - "com.github.vertical-blank" % "sql-formatter" % "2.0.5", + "com.github.vertical-blank" % "sql-formatter" % "1.0.2", "com.fasterxml.jackson.core" % "jackson-databind" % jacksonVersion, "com.fasterxml.jackson.core" % "jackson-core" % jacksonVersion, "com.fasterxml.jackson.core" % "jackson-annotations" % jacksonVersion, @@ -48,8 +49,12 @@ lazy val root = (project in file(".")) javafmtOnCompile := true, Test / testOptions := Seq(Tests.Argument(TestFrameworks.JUnit, "-a")), // Test / crossPaths := false, - Test / fork := true, - Test / javaOptions ++= Seq("-Xms1024M", "-Xmx4096M"), + Test / fork := false, +// Test / javaOptions ++= 
Seq("-Xms1024M", "-Xmx4096M"), + inConfig(CodeVerificationTests)(Defaults.testTasks), + CodeVerificationTests / testOptions += Tests.Filter(isCodeVerification), + // default test + Test / testOptions += Tests.Filter(isRemainingTest), // Release settings // usePgpKeyHex(Properties.envOrElse("GPG_SIGNATURE", "12345")), Global / pgpPassphrase := Properties.envOrNone("GPG_KEY_PASSPHRASE").map(_.toCharArray), @@ -80,3 +85,20 @@ lazy val root = (project in file(".")) } ) ) + +// Test Groups +// Code Verification +def isCodeVerification(name: String): Boolean = { + name.startsWith("com.snowflake.code_verification") +} +lazy val CodeVerificationTests = config("CodeVerificationTests") extend Test + + +// Java API Tests +// Java UDx Tests +// Scala UDx Tests +// FIPS Tests + +// other Tests +def isRemainingTest(name: String): Boolean = name.endsWith("JavaAPISuite") +// ! isCodeVerification(name) \ No newline at end of file diff --git a/scripts/format_checker.sh b/scripts/format_checker.sh index 74dc8d61..8536f7c0 100755 --- a/scripts/format_checker.sh +++ b/scripts/format_checker.sh @@ -1,5 +1,6 @@ #!/bin/bash -ex +# format src sbt clean compile if [ -z "$(git status --porcelain)" ]; then @@ -9,3 +10,24 @@ else echo "Run 'sbt clean compile' to reformat" exit 1 fi + +# format scala test +sbt test:scalafmt +if [ -z "$(git status --porcelain)" ]; then + echo "Scala Test Code Format Check: Passed!" +else + echo "Scala Test Code Format Check: Failed!" + echo "Run 'sbt test:scalafmt' to reformat" + exit 1 +fi + +# format java test +sbt test:javafmt +if [ -z "$(git status --porcelain)" ]; then + echo "Java Test Code Format Check: Passed!" +else + echo "Java Test Code Format Check: Failed!"
+ echo "Run 'sbt test:javafmt' to reformat" + exit 1 +fi + diff --git a/src/main/scala/com/snowflake/snowpark/functions.scala b/src/main/scala/com/snowflake/snowpark/functions.scala index 0090fb96..6281ec48 100644 --- a/src/main/scala/com/snowflake/snowpark/functions.scala +++ b/src/main/scala/com/snowflake/snowpark/functions.scala @@ -2858,8 +2858,8 @@ object functions { * specified string column. If the regex did not match, or the specified group did not match, an * empty string is returned. Example: from snowflake.snowpark.functions import regexp_extract * df = session.createDataFrame([["id_20_30", 10], ["id_40_50", 30]], ["id", "age"]) - * df.select(regexp_extract("id", r"(\d+)", 1).alias("RES")).show() --------- - * \|"RES" | --------- + * df.select(regexp_extract("id", r"(\d+)", 1).alias("RES")).show() --------- \|"RES" + * \| --------- * | 20 | * |:---| * | 40 | @@ -2896,9 +2896,9 @@ object functions { * Args: col: The column to evaluate its sign Example:: >>> df = * session.create_dataframe([(-2, 2, 0)], ["a", "b", "c"]) >>> * df.select(sign("a").alias("a_sign"), sign("b").alias("b_sign"), - * sign("c").alias("c_sign")).show() ---------------------------------- - * \|"A_SIGN" |"B_SIGN" |"C_SIGN" | ---------------------------------- - * \|-1 |1 |0 | ---------------------------------- + * sign("c").alias("c_sign")).show() ---------------------------------- \|"A_SIGN" |"B_SIGN" + * \|"C_SIGN" | ---------------------------------- \|-1 |1 |0 | + * ---------------------------------- * @since 1.14.0 * @param e * Column to calculate the sign.
@@ -2918,9 +2918,9 @@ object functions { * Args: col: The column to evaluate its sign Example:: >>> df = * session.create_dataframe([(-2, 2, 0)], ["a", "b", "c"]) >>> * df.select(sign("a").alias("a_sign"), sign("b").alias("b_sign"), - * sign("c").alias("c_sign")).show() ---------------------------------- - * \|"A_SIGN" |"B_SIGN" |"C_SIGN" | ---------------------------------- - * \|-1 |1 |0 | ---------------------------------- + * sign("c").alias("c_sign")).show() ---------------------------------- \|"A_SIGN" |"B_SIGN" + * \|"C_SIGN" | ---------------------------------- \|-1 |1 |0 | + * ---------------------------------- * @since 1.14.0 * @param e * Column to calculate the sign. @@ -2973,8 +2973,8 @@ object functions { /** Returns the input values, pivoted into an ARRAY. If the input is empty, an empty ARRAY is * returned. Example:: >>> df = session.create_dataframe([[1], [2], [3], [1]], schema=["a"]) - * >>> df.select(array_agg("a", True).alias("result")).show() ------------ - * \|"RESULT" | ------------ + * >>> df.select(array_agg("a", True).alias("result")).show() ------------ \|"RESULT" | + * ------------ * | [ | * |:---| * | 1, | @@ -2994,8 +2994,7 @@ object functions { * returned. 
* * Example:: >>> df = session.create_dataframe([[1], [2], [3], [1]], schema=["a"]) >>> - * df.select(array_agg("a", True).alias("result")).show() ------------ - * \|"RESULT" | ------------ + * df.select(array_agg("a", True).alias("result")).show() ------------ \|"RESULT" | ------------ * | [ | * |:---| * | 1, | diff --git a/src/test/java/com/snowflake/snowpark/TestRunner.java b/src/test/java/com/snowflake/snowpark/TestRunner.java deleted file mode 100644 index a2944453..00000000 --- a/src/test/java/com/snowflake/snowpark/TestRunner.java +++ /dev/null @@ -1,41 +0,0 @@ -package com.snowflake.snowpark; - -import org.junit.internal.TextListener; -import org.junit.runner.JUnitCore; -import org.junit.runner.Result; -import org.junit.runner.notification.Failure; - -public class TestRunner { - private static final JUnitCore junit = init(); - - private static JUnitCore init() { - // turn on assertion, otherwise Java assert doesn't work. - TestRunner.class.getClassLoader().setDefaultAssertionStatus(true); - JUnitCore jCore = new JUnitCore(); - jCore.addListener(new TextListener(System.out)); - return jCore; - } - - public static void run(Class clazz) { - Result result = junit.run(clazz); - - StringBuilder failures = new StringBuilder(); - for (Failure f : result.getFailures()) { - failures.append(f.getTrimmedTrace()); - } - - System.out.println( - "Finished. Result: Failures: " - + result.getFailureCount() - + " Ignored: " - + result.getIgnoreCount() - + " Tests run: " - + result.getRunCount() - + ". 
Time: " - + result.getRunTime() - + "ms."); - if (result.getFailureCount() > 0) { - throw new RuntimeException("Failures:\n" + failures.toString()); - } - } -} diff --git a/src/test/scala/com/snowflake/code_verification/ClassUtils.scala b/src/test/scala/com/snowflake/code_verification/ClassUtils.scala index 30cda4dd..bd20025f 100644 --- a/src/test/scala/com/snowflake/code_verification/ClassUtils.scala +++ b/src/test/scala/com/snowflake/code_verification/ClassUtils.scala @@ -36,8 +36,7 @@ object ClassUtils extends Logging { class2: Class[B], class1Only: Set[String] = Set.empty, class2Only: Set[String] = Set.empty, - class1To2NameMap: Map[String, String] = Map.empty - ): Boolean = { + class1To2NameMap: Map[String, String] = Map.empty): Boolean = { val nameList1 = getAllPublicFunctionNames(class1) val nameList2 = mutable.Set[String](getAllPublicFunctionNames(class2): _*) var missed = false @@ -63,8 +62,7 @@ object ClassUtils extends Logging { list2Cache.remove(name) } else { logError(s"${class1.getName} misses function $name") - } - ) + }) !missed && list2Cache.isEmpty } } diff --git a/src/test/scala/com/snowflake/code_verification/JavaScalaAPISuite.scala b/src/test/scala/com/snowflake/code_verification/JavaScalaAPISuite.scala index 4252243a..f71a8182 100644 --- a/src/test/scala/com/snowflake/code_verification/JavaScalaAPISuite.scala +++ b/src/test/scala/com/snowflake/code_verification/JavaScalaAPISuite.scala @@ -18,8 +18,7 @@ class JavaScalaAPISuite extends FunSuite { "productArity", "unapply", "tupled", - "curried" - ) + "curried") // used to get list of Scala Seq functions class FakeSeq extends Seq[String] { @@ -44,11 +43,8 @@ class JavaScalaAPISuite extends FunSuite { "column", // Java API has "col", Scala API has both "col" and "column" "callBuiltin", // Java API has "callUDF", Scala API has both "callBuiltin" and "callUDF" "typedLit", // Scala API only, Java API has lit - "builtin" - ), - class1To2NameMap = Map("chr" -> "char") - ) - ) + "builtin"), + 
class1To2NameMap = Map("chr" -> "char"))) } test("AsyncJob") { @@ -72,9 +68,7 @@ class JavaScalaAPISuite extends FunSuite { class2Only = Set( "name" // Java API has "alias" ) ++ scalaCaseClassFunctions, - class1To2NameMap = Map("subField" -> "apply") - ) - ) + class1To2NameMap = Map("subField" -> "apply"))) } test("CaseExpr") { @@ -85,9 +79,7 @@ class JavaScalaAPISuite extends FunSuite { classOf[JavaCaseExpr], classOf[ScalaCaseExpr], // Java API has "otherwise", Scala API has both "otherwise" and "else" - class2Only = Set("else") - ) - ) + class2Only = Set("else"))) } test("DataFrame") { @@ -102,10 +94,7 @@ class JavaScalaAPISuite extends FunSuite { "getUnaliased", "methodChainCache", "buildMethodChain", - "generatePrefix" - ) ++ scalaCaseClassFunctions - ) - ) + "generatePrefix") ++ scalaCaseClassFunctions)) } test("CopyableDataFrame") { @@ -116,9 +105,7 @@ class JavaScalaAPISuite extends FunSuite { classOf[JavaCopyableDataFrame], classOf[ScalaCopyableDataFrame], // package private - class2Only = Set("getCopyDataFrame") - ) - ) + class2Only = Set("getCopyDataFrame"))) } test("CopyableDataFrameAsyncActor") { @@ -129,9 +116,7 @@ class JavaScalaAPISuite extends FunSuite { assert( ClassUtils.containsSameFunctionNames( classOf[JavaCopyableDataFrameAsyncActor], - classOf[ScalaCopyableDataFrameAsyncActor] - ) - ) + classOf[ScalaCopyableDataFrameAsyncActor])) } test("DataFrameAsyncActor") { @@ -140,9 +125,7 @@ class JavaScalaAPISuite extends FunSuite { assert( ClassUtils.containsSameFunctionNames( classOf[JavaDataFrameAsyncActor], - classOf[ScalaDataFrameAsyncActor] - ) - ) + classOf[ScalaDataFrameAsyncActor])) } test("DataFrameNaFunctions") { @@ -151,9 +134,7 @@ class JavaScalaAPISuite extends FunSuite { assert( ClassUtils.containsSameFunctionNames( classOf[JavaDataFrameNaFunctions], - classOf[ScalaDataFrameNaFunctions] - ) - ) + classOf[ScalaDataFrameNaFunctions])) } test("DataFrameReader") { @@ -161,8 +142,7 @@ class JavaScalaAPISuite extends FunSuite { import 
com.snowflake.snowpark.{DataFrameReader => ScalaDataFrameReader} assert( ClassUtils - .containsSameFunctionNames(classOf[JavaDataFrameReader], classOf[ScalaDataFrameReader]) - ) + .containsSameFunctionNames(classOf[JavaDataFrameReader], classOf[ScalaDataFrameReader])) } test("DataFrameStatFunctions") { @@ -171,9 +151,7 @@ class JavaScalaAPISuite extends FunSuite { assert( ClassUtils.containsSameFunctionNames( classOf[JavaDataFrameStatFunctions], - classOf[ScalaDataFrameStatFunctions] - ) - ) + classOf[ScalaDataFrameStatFunctions])) } test("DataFrameWriter") { @@ -184,9 +162,7 @@ class JavaScalaAPISuite extends FunSuite { classOf[JavaDataFrameWriter], classOf[ScalaDataFrameWriter], // package private - class2Only = Set("getWritePlan") - ) - ) + class2Only = Set("getWritePlan"))) } test("DataFrameWriterAsyncActor") { @@ -195,9 +171,7 @@ class JavaScalaAPISuite extends FunSuite { assert( ClassUtils.containsSameFunctionNames( classOf[JavaDataFrameWriterAsyncActor], - classOf[ScalaDataFrameWriterAsyncActor] - ) - ) + classOf[ScalaDataFrameWriterAsyncActor])) } test("DeleteResult") { @@ -209,9 +183,7 @@ class JavaScalaAPISuite extends FunSuite { classOf[ScalaDeleteResult], class1Only = Set(), class2Only = scalaCaseClassFunctions, - class1To2NameMap = Map("getRowsDeleted" -> "rowsDeleted") - ) - ) + class1To2NameMap = Map("getRowsDeleted" -> "rowsDeleted"))) } test("FileOperation") { @@ -219,8 +191,7 @@ class JavaScalaAPISuite extends FunSuite { import com.snowflake.snowpark.{FileOperation => ScalaFileOperation} assert( ClassUtils - .containsSameFunctionNames(classOf[JavaFileOperation], classOf[ScalaFileOperation]) - ) + .containsSameFunctionNames(classOf[JavaFileOperation], classOf[ScalaFileOperation])) } test("GetResult") { @@ -237,10 +208,7 @@ class JavaScalaAPISuite extends FunSuite { "getStatus" -> "status", "getSizeBytes" -> "sizeBytes", "getMessage" -> "message", - "getFileName" -> "fileName" - ) - ) - ) + "getFileName" -> "fileName"))) } test("GroupingSets") { @@ 
-252,9 +220,7 @@ class JavaScalaAPISuite extends FunSuite { classOf[ScalaGroupingSets], class1Only = Set(), class2Only = Set("sets", "toExpression") ++ scalaCaseClassFunctions, - class1To2NameMap = Map("create" -> "apply") - ) - ) + class1To2NameMap = Map("create" -> "apply"))) } test("HasCachedResult") { @@ -262,8 +228,7 @@ class JavaScalaAPISuite extends FunSuite { import com.snowflake.snowpark.{HasCachedResult => ScalaHasCachedResult} assert( ClassUtils - .containsSameFunctionNames(classOf[JavaHasCachedResult], classOf[ScalaHasCachedResult]) - ) + .containsSameFunctionNames(classOf[JavaHasCachedResult], classOf[ScalaHasCachedResult])) } test("MatchedClauseBuilder") { @@ -274,17 +239,14 @@ class JavaScalaAPISuite extends FunSuite { classOf[JavaMatchedClauseBuilder], classOf[ScalaMatchedClauseBuilder], // scala api has update[T] - class1Only = Set("updateColumn") - ) - ) + class1Only = Set("updateColumn"))) } test("MergeBuilder") { import com.snowflake.snowpark_java.{MergeBuilder => JavaMergeBuilder} import com.snowflake.snowpark.{MergeBuilder => ScalaMergeBuilder} assert( - ClassUtils.containsSameFunctionNames(classOf[JavaMergeBuilder], classOf[ScalaMergeBuilder]) - ) + ClassUtils.containsSameFunctionNames(classOf[JavaMergeBuilder], classOf[ScalaMergeBuilder])) } test("MergeResult") { @@ -299,10 +261,7 @@ class JavaScalaAPISuite extends FunSuite { class1To2NameMap = Map( "getRowsInserted" -> "rowsInserted", "getRowsUpdated" -> "rowsUpdated", - "getRowsDeleted" -> "rowsDeleted" - ) - ) - ) + "getRowsDeleted" -> "rowsDeleted"))) } test("NotMatchedClauseBuilder") { @@ -312,9 +271,7 @@ class JavaScalaAPISuite extends FunSuite { ClassUtils.containsSameFunctionNames( classOf[JavaNotMatchedClauseBuilder], classOf[ScalaNotMatchedClauseBuilder], - class1Only = Set("insertRow") - ) - ) + class1Only = Set("insertRow"))) } test("PutResult") { @@ -335,10 +292,7 @@ class JavaScalaAPISuite extends FunSuite { "getSourceFileName" -> "sourceFileName", "getSourceCompression" -> 
"sourceCompression", "getTargetSizeBytes" -> "targetSizeBytes", - "getSourceSizeBytes" -> "sourceSizeBytes" - ) - ) - ) + "getSourceSizeBytes" -> "sourceSizeBytes"))) } test("RelationalGroupedDataFrame") { @@ -349,9 +303,7 @@ class JavaScalaAPISuite extends FunSuite { assert( ClassUtils.containsSameFunctionNames( classOf[JavaRelationalGroupedDataFrame], - classOf[ScalaRelationalGroupedDataFrame] - ) - ) + classOf[ScalaRelationalGroupedDataFrame])) } test("Row") { @@ -371,10 +323,7 @@ class JavaScalaAPISuite extends FunSuite { "toList" -> "toSeq", "create" -> "apply", "getListOfVariant" -> "getSeqOfVariant", - "getList" -> "getSeq" - ) - ) - ) + "getList" -> "getSeq"))) } // Java SaveMode is an Enum, @@ -392,10 +341,7 @@ class JavaScalaAPISuite extends FunSuite { "storedProcedure", // todo in snow-683655 "sproc", // todo in snow-683653 "getDependenciesAsJavaSet", // Java API renamed to "getDependencies" - "implicits" - ) - ) - ) + "implicits"))) } test("SessionBuilder") { @@ -403,8 +349,7 @@ class JavaScalaAPISuite extends FunSuite { import com.snowflake.snowpark.Session.{SessionBuilder => ScalaSessionBuilder} assert( ClassUtils - .containsSameFunctionNames(classOf[JavaSessionBuilder], classOf[ScalaSessionBuilder]) - ) + .containsSameFunctionNames(classOf[JavaSessionBuilder], classOf[ScalaSessionBuilder])) } test("TableFunction") { @@ -415,9 +360,7 @@ class JavaScalaAPISuite extends FunSuite { classOf[JavaTableFunction], classOf[ScalaTableFunction], class1Only = Set("call"), // `call` in Scala is `apply` - class2Only = Set("funcName") ++ scalaCaseClassFunctions - ) - ) + class2Only = Set("funcName") ++ scalaCaseClassFunctions)) } test("TableFunctions") { @@ -425,8 +368,7 @@ class JavaScalaAPISuite extends FunSuite { import com.snowflake.snowpark.{tableFunctions => ScalaTableFunctions} assert( ClassUtils - .containsSameFunctionNames(classOf[JavaTableFunctions], ScalaTableFunctions.getClass) - ) + .containsSameFunctionNames(classOf[JavaTableFunctions], 
ScalaTableFunctions.getClass)) } test("TypedAsyncJob") { @@ -434,8 +376,7 @@ class JavaScalaAPISuite extends FunSuite { import com.snowflake.snowpark.{TypedAsyncJob => ScalaTypedAsyncJob} assert( ClassUtils - .containsSameFunctionNames(classOf[JavaTypedAsyncJob[_]], classOf[ScalaTypedAsyncJob[_]]) - ) + .containsSameFunctionNames(classOf[JavaTypedAsyncJob[_]], classOf[ScalaTypedAsyncJob[_]])) } test("UDFRegistration") { @@ -443,8 +384,7 @@ class JavaScalaAPISuite extends FunSuite { import com.snowflake.snowpark.{UDFRegistration => ScalaUDFRegistration} assert( ClassUtils - .containsSameFunctionNames(classOf[JavaUDFRegistration], classOf[ScalaUDFRegistration]) - ) + .containsSameFunctionNames(classOf[JavaUDFRegistration], classOf[ScalaUDFRegistration])) } test("Updatable") { @@ -454,9 +394,7 @@ class JavaScalaAPISuite extends FunSuite { ClassUtils.containsSameFunctionNames( classOf[JavaUpdatable], classOf[ScalaUpdatable], - class1Only = Set("updateColumn") - ) - ) + class1Only = Set("updateColumn"))) } test("UpdatableAsyncActor") { @@ -466,9 +404,7 @@ class JavaScalaAPISuite extends FunSuite { ClassUtils.containsSameFunctionNames( classOf[JavaUpdatableAsyncActor], classOf[ScalaUpdatableAsyncActor], - class1Only = Set("updateColumn") - ) - ) + class1Only = Set("updateColumn"))) } test("UpdateResult") { @@ -482,10 +418,7 @@ class JavaScalaAPISuite extends FunSuite { class2Only = scalaCaseClassFunctions, class1To2NameMap = Map( "getRowsUpdated" -> "rowsUpdated", - "getMultiJoinedRowsUpdated" -> "multiJoinedRowsUpdated" - ) - ) - ) + "getMultiJoinedRowsUpdated" -> "multiJoinedRowsUpdated"))) } test("UserDefinedFunction") { @@ -497,9 +430,7 @@ class JavaScalaAPISuite extends FunSuite { classOf[ScalaUserDefinedFunction], class1Only = Set(), class2Only = Set("f", "returnType", "name", "inputTypes", "withName") ++ - scalaCaseClassFunctions - ) - ) + scalaCaseClassFunctions)) } test("Windows") { @@ -524,9 +455,7 @@ class JavaScalaAPISuite extends FunSuite { 
classOf[ScalaArrayType], class1Only = Set(), class2Only = scalaCaseClassFunctions, - class1To2NameMap = Map("getElementType" -> "elementType") - ) - ) + class1To2NameMap = Map("getElementType" -> "elementType"))) } test("BinaryType") { @@ -537,9 +466,7 @@ class JavaScalaAPISuite extends FunSuite { classOf[JavaBinaryType], ScalaBinaryType.getClass, class1Only = Set(), - class2Only = scalaCaseClassFunctions - ) - ) + class2Only = scalaCaseClassFunctions)) } test("BooleanType") { @@ -550,9 +477,7 @@ class JavaScalaAPISuite extends FunSuite { classOf[JavaBooleanType], ScalaBooleanType.getClass, class1Only = Set(), - class2Only = scalaCaseClassFunctions - ) - ) + class2Only = scalaCaseClassFunctions)) } test("ByteType") { @@ -563,9 +488,7 @@ class JavaScalaAPISuite extends FunSuite { classOf[JavaByteType], ScalaByteType.getClass, class1Only = Set(), - class2Only = scalaCaseClassFunctions - ) - ) + class2Only = scalaCaseClassFunctions)) } test("ColumnIdentifier") { @@ -576,9 +499,7 @@ class JavaScalaAPISuite extends FunSuite { classOf[JavaColumnIdentifier], classOf[ScalaColumnIdentifier], class1Only = Set(), - class2Only = scalaCaseClassFunctions - ) - ) + class2Only = scalaCaseClassFunctions)) } test("DateType") { @@ -589,9 +510,7 @@ class JavaScalaAPISuite extends FunSuite { classOf[JavaDateType], ScalaDateType.getClass, class1Only = Set(), - class2Only = scalaCaseClassFunctions - ) - ) + class2Only = scalaCaseClassFunctions)) } test("DecimalType") { @@ -603,9 +522,7 @@ class JavaScalaAPISuite extends FunSuite { ScalaDecimalType.getClass, class1Only = Set("getPrecision", "getScale"), class2Only = Set("MAX_SCALE", "MAX_PRECISION") ++ - scalaCaseClassFunctions - ) - ) + scalaCaseClassFunctions)) } test("DoubleType") { @@ -627,9 +544,7 @@ class JavaScalaAPISuite extends FunSuite { ClassUtils.containsSameFunctionNames( classOf[JavaGeography], classOf[ScalaGeograhy], - class2Only = Set("getString") - ) - ) + class2Only = Set("getString"))) } test("GeographyType") { @@ 
-637,8 +552,7 @@ class JavaScalaAPISuite extends FunSuite { import com.snowflake.snowpark.types.{GeographyType => ScalaGeograhyType} assert( ClassUtils - .containsSameFunctionNames(classOf[JavaGeographyType], ScalaGeograhyType.getClass) - ) + .containsSameFunctionNames(classOf[JavaGeographyType], ScalaGeograhyType.getClass)) } test("Geometry") { @@ -652,16 +566,14 @@ class JavaScalaAPISuite extends FunSuite { import com.snowflake.snowpark.types.{GeometryType => ScalaGeometryType} assert( ClassUtils - .containsSameFunctionNames(classOf[JavaGeometryType], ScalaGeometryType.getClass) - ) + .containsSameFunctionNames(classOf[JavaGeometryType], ScalaGeometryType.getClass)) } test("IntegerType") { import com.snowflake.snowpark_java.types.{IntegerType => JavaIntegerType} import com.snowflake.snowpark.types.{IntegerType => ScalaIntegerType} assert( - ClassUtils.containsSameFunctionNames(classOf[JavaIntegerType], ScalaIntegerType.getClass) - ) + ClassUtils.containsSameFunctionNames(classOf[JavaIntegerType], ScalaIntegerType.getClass)) } test("LongType") { @@ -678,9 +590,7 @@ class JavaScalaAPISuite extends FunSuite { classOf[JavaMapType], ScalaMapType.getClass, class1Only = Set("getValueType", "getKeyType"), - class2Only = Set("unapply") - ) - ) + class2Only = Set("unapply"))) } test("ShortType") { @@ -700,8 +610,7 @@ class JavaScalaAPISuite extends FunSuite { import com.snowflake.snowpark.types.{TimestampType => ScalaTimestampType} assert( ClassUtils - .containsSameFunctionNames(classOf[JavaTimestampType], ScalaTimestampType.getClass) - ) + .containsSameFunctionNames(classOf[JavaTimestampType], ScalaTimestampType.getClass)) } test("TimeType") { @@ -714,8 +623,7 @@ class JavaScalaAPISuite extends FunSuite { import com.snowflake.snowpark_java.types.{VariantType => JavaVariantType} import com.snowflake.snowpark.types.{VariantType => ScalaVariantType} assert( - ClassUtils.containsSameFunctionNames(classOf[JavaVariantType], ScalaVariantType.getClass) - ) + 
ClassUtils.containsSameFunctionNames(classOf[JavaVariantType], ScalaVariantType.getClass)) } test("Variant") { @@ -727,9 +635,7 @@ class JavaScalaAPISuite extends FunSuite { classOf[ScalaVariant], class1Only = Set(), class2Only = Set("dataType", "value"), - class1To2NameMap = Map("asBigInteger" -> "asBigInt", "asList" -> "asSeq") - ) - ) + class1To2NameMap = Map("asBigInteger" -> "asBigInt", "asList" -> "asSeq"))) } test("StructField") { @@ -740,9 +646,7 @@ class JavaScalaAPISuite extends FunSuite { classOf[JavaStructField], classOf[ScalaStructField], class1Only = Set(), - class2Only = Set("treeString") ++ scalaCaseClassFunctions - ) - ) + class2Only = Set("treeString") ++ scalaCaseClassFunctions)) } test("StructType") { @@ -756,13 +660,10 @@ class JavaScalaAPISuite extends FunSuite { // Java Iterable "forEach", "get", - "spliterator" - ), + "spliterator"), class2Only = Set("fields") ++ scalaSeqFunctions ++ scalaCaseClassFunctions, - class1To2NameMap = Map("create" -> "apply") - ) - ) + class1To2NameMap = Map("create" -> "apply"))) } } diff --git a/src/test/scala/com/snowflake/code_verification/PomSuite.scala b/src/test/scala/com/snowflake/code_verification/PomSuite.scala index adcd4750..b5b39f4d 100644 --- a/src/test/scala/com/snowflake/code_verification/PomSuite.scala +++ b/src/test/scala/com/snowflake/code_verification/PomSuite.scala @@ -12,23 +12,21 @@ class PomSuite extends FunSuite { private val fipsPomFileName = "fips-pom.xml" private val javaDocPomFileName = "java_doc.xml" - test("project versions should be updated together") { + // todo: should be replaced by SBT + ignore("project versions should be updated together") { assert( PomUtils.getProjectVersion(pomFileName) == - PomUtils.getProjectVersion(javaDocPomFileName) - ) + PomUtils.getProjectVersion(javaDocPomFileName)) assert( PomUtils.getProjectVersion(pomFileName) == - PomUtils.getProjectVersion(fipsPomFileName) - ) + PomUtils.getProjectVersion(fipsPomFileName)) assert( PomUtils 
.getProjectVersion(pomFileName) - .matches("\\d+\\.\\d+\\.\\d+(-SNAPSHOT)?") - ) + .matches("\\d+\\.\\d+\\.\\d+(-SNAPSHOT)?")) } - test("dependencies of pom and fips should be updated together") { + ignore("dependencies of pom and fips should be updated together") { val pomDependencies = PomUtils.getProductDependencies(pomFileName) val fipsDependencies = PomUtils.getProductDependencies(fipsPomFileName) diff --git a/src/test/scala/com/snowflake/perf/PerfBase.scala b/src/test/scala/com/snowflake/perf/PerfBase.scala index 6094f29e..bc0e46da 100644 --- a/src/test/scala/com/snowflake/perf/PerfBase.scala +++ b/src/test/scala/com/snowflake/perf/PerfBase.scala @@ -12,9 +12,9 @@ trait PerfBase extends SNTestBase { // to enable perf test use maven flag `-DargLine="-DPERF_TEST=true"` lazy val isPerfTest: Boolean = System.getProperty("PERF_TEST") match { - case null => false + case null => false case value if value.toLowerCase() == "true" => true - case _ => false + case _ => false } lazy val snowhouseProfile: String = "snowhouse.properties" @@ -34,8 +34,7 @@ trait PerfBase extends SNTestBase { Paths.get(resultFileName), data.getBytes, StandardOpenOption.CREATE, - StandardOpenOption.APPEND - ) + StandardOpenOption.APPEND) } override def beforeAll: Unit = { @@ -51,8 +50,7 @@ trait PerfBase extends SNTestBase { snowhouseSession.file.put(s"file://$resultFileName", s"@$tmpStageName") val fileSchema = StructType( StructField("TEST_NAME", StringType), - StructField("TIME_CONSUMPTION", DoubleType) - ) + StructField("TIME_CONSUMPTION", DoubleType)) snowhouseSession.read .schema(fileSchema) .csv(s"@$tmpStageName") diff --git a/src/test/scala/com/snowflake/snowpark/APIInternalSuite.scala b/src/test/scala/com/snowflake/snowpark/APIInternalSuite.scala index 24164c6d..08fa7dec 100644 --- a/src/test/scala/com/snowflake/snowpark/APIInternalSuite.scala +++ b/src/test/scala/com/snowflake/snowpark/APIInternalSuite.scala @@ -29,8 +29,7 @@ import scala.util.Random class APIInternalSuite extends 
TestData { private val userSchema: StructType = StructType( - Seq(StructField("a", IntegerType), StructField("b", StringType), StructField("c", DoubleType)) - ) + Seq(StructField("a", IntegerType), StructField("b", StringType), StructField("c", DoubleType))) val tmpStageName: String = randomStageName() @@ -91,8 +90,7 @@ class APIInternalSuite extends TestData { } assert(ex.errorCode.equals("0416")) assert( - ex.message.contains("Cannot close this session because it is used by stored procedure.") - ) + ex.message.contains("Cannot close this session because it is used by stored procedure.")) } finally { Session.resetGlobalStoredProcSession() } @@ -144,8 +142,8 @@ class APIInternalSuite extends TestData { throw new TestFailedException("Expect an exception") } catch { case _: SnowparkClientException => // expected - case _: SnowflakeSQLException => // expected - case e => throw e + case _: SnowflakeSQLException => // expected + case e => throw e } try { @@ -158,8 +156,8 @@ class APIInternalSuite extends TestData { throw new TestFailedException("Expect an exception") } catch { case _: SnowparkClientException => // expected - case _: SnowflakeSQLException => // expected - case e => throw e + case _: SnowflakeSQLException => // expected + case e => throw e } // int max is 2147483647 @@ -168,8 +166,7 @@ class APIInternalSuite extends TestData { .configFile(defaultProfile) .config(SnowparkRequestTimeoutInSeconds, "2147483648") .create - .requestTimeoutInSeconds - ) + .requestTimeoutInSeconds) // int min is -2147483648 assertThrows[SnowparkClientException]( @@ -177,16 +174,14 @@ class APIInternalSuite extends TestData { .configFile(defaultProfile) .config(SnowparkRequestTimeoutInSeconds, "-2147483649") .create - .requestTimeoutInSeconds - ) + .requestTimeoutInSeconds) assertThrows[SnowparkClientException]( Session.builder .configFile(defaultProfile) .config(SnowparkRequestTimeoutInSeconds, "abcd") .create - .requestTimeoutInSeconds - ) + .requestTimeoutInSeconds) } @@ -216,9 
+211,7 @@ class APIInternalSuite extends TestData { assert(ex.errorCode.equals("0418")) assert( ex.message.contains( - "Invalid value negative_not_number for parameter snowpark_max_file_upload_retry_count." - ) - ) + "Invalid value negative_not_number for parameter snowpark_max_file_upload_retry_count.")) } test("cancel all", UnstableTest) { @@ -230,8 +223,7 @@ class APIInternalSuite extends TestData { random().as("b"), random().as("c"), random().as("d"), - random().as("e") - ) + random().as("e")) try { val q1 = testCanceled { @@ -273,8 +265,8 @@ class APIInternalSuite extends TestData { .plus(df.col("c")) .plus(df.col("d")) .plus(df.col("e")) - .as("result") - ).filter(df.col("result").gt(com.snowflake.snowpark_java.Functions.lit(0))) + .as("result")) + .filter(df.col("result").gt(com.snowflake.snowpark_java.Functions.lit(0))) .count() } @@ -334,8 +326,7 @@ class APIInternalSuite extends TestData { newSession.close() assert( Session.getActiveSession.isEmpty || - Session.getActiveSession.get != newSession - ) + Session.getActiveSession.get != newSession) // It's no problem to close the session multiple times newSession.close() @@ -368,11 +359,8 @@ class APIInternalSuite extends TestData { Literal(this) } assert( - ex.getMessage.contains( - "Cannot create a Literal for com.snowflake.snowpark." + - "APIInternalSuite(APIInternalSuite)" - ) - ) + ex.getMessage.contains("Cannot create a Literal for com.snowflake.snowpark." 
+ + "APIInternalSuite(APIInternalSuite)")) } test("special BigDecimal literals") { @@ -398,8 +386,7 @@ class APIInternalSuite extends TestData { ||0.1 |0.00001 |100000 | ||0.1 |0.00001 |100000 | |------------------------------------------------------------------------------------ - |""".stripMargin - ) + |""".stripMargin) } test("show structured types mix") { @@ -427,8 +414,7 @@ class APIInternalSuite extends TestData { |-------------------------------------------------------------------------------------------------------------------------------------------------- ||NULL |1 |abc |{b:2,a:1} |{2:b,1:a} |Object(a:1,b:Array(1,2,3,4)) |[1,2,3] |[1.1,2.2,3.3] |{a1:Object(b:2),a2:Object(b:3)} | |-------------------------------------------------------------------------------------------------------------------------------------------------- - |""".stripMargin - ) + |""".stripMargin) // scalastyle:on } } @@ -456,8 +442,7 @@ class APIInternalSuite extends TestData { |---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- ||Object(a:"22",b:1) |Object(a:1,b:Array(1,2,3,4)) |Object(a:"1",b:Array(1,2,3,4),c:Map(1:"a")) |Object(a:Object(b:Object(c:1,a:10))) |[Object(a:1,b:2),Object(a:4,b:3)] |{a1:Object(b:2),a2:Object(b:3)} | |---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- - |""".stripMargin - ) + |""".stripMargin) // scalastyle:on } } @@ -486,8 +471,7 @@ class APIInternalSuite extends TestData { || | | | | | "b": 2 | || | | | | |} | |--------------------------------------------------------------------------------------------------------- - |""".stripMargin - ) + |""".stripMargin) // scalastyle:on } } @@ -521,8 +505,7 @@ class APIInternalSuite extends TestData { || | | 
| | | | | 2 |}] | | | | || | | | | | | |]] | | | | | |--------------------------------------------------------------------------------------------------------------------------------------------------------------------------- - |""".stripMargin - ) + |""".stripMargin) // scalastyle:on } } @@ -538,8 +521,7 @@ class APIInternalSuite extends TestData { .last .sql .trim - .startsWith("SELECT *, 1 :: int AS \"NEWCOL\" FROM") - ) + .startsWith("SELECT *, 1 :: int AS \"NEWCOL\" FROM")) // use full name list if replacing existing column assert( @@ -549,8 +531,7 @@ class APIInternalSuite extends TestData { .last .sql .trim - .startsWith("SELECT \"B\", 1 :: int AS \"A\" FROM") - ) + .startsWith("SELECT \"B\", 1 :: int AS \"A\" FROM")) } test("union by name should not list all columns if not reorder") { @@ -616,8 +597,7 @@ class APIInternalSuite extends TestData { }, ParameterUtils.SnowparkUseScopedTempObjects, - "true" - ) + "true") } // functions @@ -633,18 +613,14 @@ class APIInternalSuite extends TestData { seq4(), seq4(false), seq8(), - seq8(false) - ) + seq8(false)) .snowflakePlan .queries assert(queries.size == 1) assert( - queries.head.sql.contains( - "SELECT seq1(0), seq1(1), seq2(0), seq2(1), seq4(0)," + - " seq4(1), seq8(0), seq8(1) FROM ( TABLE (GENERATOR(ROWCOUNT => 10)))" - ) - ) + queries.head.sql.contains("SELECT seq1(0), seq1(1), seq2(0), seq2(1), seq4(0)," + + " seq4(1), seq8(0), seq8(1) FROM ( TABLE (GENERATOR(ROWCOUNT => 10)))")) } // This test DataFrame can't be defined in TestData, @@ -654,8 +630,7 @@ class APIInternalSuite extends TestData { val queries = Seq( s"create temporary table $tableName1 (A int)", s"insert into $tableName1 values(1),(2),(3)", - s"select * from $tableName1" - ).map(Query(_)) + s"select * from $tableName1").map(Query(_)) val attrs = Seq(Attribute("A", IntegerType, nullable = true)) val postActions = Seq(Query(s"drop table if exists $tableName1", true)) val plan = @@ -665,8 +640,7 @@ class APIInternalSuite extends TestData { 
postActions, session, None, - supportAsyncMode = true - ) + supportAsyncMode = true) new DataFrame(session, session.analyzer.resolve(plan), Seq()) } @@ -678,8 +652,7 @@ class APIInternalSuite extends TestData { val queries2 = Seq( s"create temporary table $tableName2 (A int, B string)", s"insert into $tableName2 values(1, 'a'), (2, 'b'), (3, 'c')", - s"select * from $tableName2" - ).map(Query(_)) + s"select * from $tableName2").map(Query(_)) val attrs2 = Seq(Attribute("A", IntegerType, nullable = true), Attribute("B", StringType, nullable = true)) val postActions2 = Seq(Query(s"drop table if exists $tableName2")) @@ -690,8 +663,7 @@ class APIInternalSuite extends TestData { postActions2, session, None, - supportAsyncMode = true - ) + supportAsyncMode = true) new DataFrame(session, session.analyzer.resolve(plan2), Seq()) } @@ -709,31 +681,25 @@ class APIInternalSuite extends TestData { df2.select(Column("A"), Column("B"), col(df3)).show() checkAnswer( df2.select(Column("A"), Column("B"), col(df3)), - Seq(Row(1, "a", 1), Row(2, "b", 1), Row(3, "c", 1)) - ) + Seq(Row(1, "a", 1), Row(2, "b", 1), Row(3, "c", 1))) // SELECT 2 sub queries checkAnswer( df2.select(Column("A"), Column("B"), col(df3).as("s1"), col(df3).as("s2")), - Seq(Row(1, "a", 1, 1), Row(2, "b", 1, 1), Row(3, "c", 1, 1)) - ) + Seq(Row(1, "a", 1, 1), Row(2, "b", 1, 1), Row(3, "c", 1, 1))) // WHERE 2 sub queries checkAnswer( df2.filter(col("A") > col(df3) and col("A") < col(df1.groupBy().agg(max(col("A"))))), - Seq(Row(2, "b")) - ) + Seq(Row(2, "b"))) // SELECT 2 sub queries + WHERE 2 sub queries checkAnswer( df2 .select(col("A"), col("B"), col(df3).as("s1"), col(df1.filter(col("A") === 2)).as("s2")) - .filter( - col("A") > col(df1.groupBy().agg(mean(col("A")))) and - col("A") <= col(df1.groupBy().agg(max(col("A")))) - ), - Seq(Row(3, "c", 1, 2)) - ) + .filter(col("A") > col(df1.groupBy().agg(mean(col("A")))) and + col("A") <= col(df1.groupBy().agg(max(col("A"))))), + Seq(Row(3, "c", 1, 2))) } 
test("explain") { @@ -755,8 +721,7 @@ class APIInternalSuite extends TestData { schemaValueStatement(Seq(Attribute("NUM", LongType))), session, None, - supportAsyncMode = true - ) + supportAsyncMode = true) val df = new DataFrame(session, plan, Seq()) df.explain() @@ -824,8 +789,7 @@ class APIInternalSuite extends TestData { "com.snowflake.snowpark.test", "TestClass", "snowpark_test_", - "test.jar" - ) + "test.jar") val fileName = TestUtils.getFileName(filePath) val miscCommands = Seq( @@ -840,8 +804,7 @@ class APIInternalSuite extends TestData { s"create temp view $objectName (string) as select current_version()", s"drop view $objectName", s"show tables", - s"drop stage $stageName" - ) + s"drop stage $stageName") // Misc commands with show() miscCommands.foreach(session.sql(_).show()) @@ -887,8 +850,7 @@ class APIInternalSuite extends TestData { .option("on_error", "continue") .option("COMPRESSION", "gzip") .csv(testFileOnStage), - Seq() - ) + Seq()) } // The constructor for AsyncJob/TypedAsyncJob is package private. 
@@ -943,8 +905,7 @@ class APIInternalSuite extends TestData { val rows = session .sql( s"select QUERY_TAG from table(information_schema.QUERY_HISTORY_BY_SESSION())" + - s" where QUERY_ID = '$queryId'" - ) + s" where QUERY_ID = '$queryId'") .collect() assert(rows.length == 1 && rows(0).getString(0).equals(testQueryTagValue)) } @@ -971,8 +932,7 @@ class APIInternalSuite extends TestData { } assert( getEx.getMessage.contains("Result for query") - && getEx.getMessage.contains("has expired") - ) + && getEx.getMessage.contains("has expired")) val dfPut = session.sql(s"put file://$path/$testFileCsv @$tmpStageName/testExecuteAndGetQueryId") @@ -982,8 +942,7 @@ class APIInternalSuite extends TestData { } assert( putEx.getMessage.contains("Result for query") - && putEx.getMessage.contains("has expired") - ) + && putEx.getMessage.contains("has expired")) } finally { TestUtils.removeFile(path, session) } @@ -1015,9 +974,7 @@ class APIInternalSuite extends TestData { StructField("boolean", BooleanType), StructField("binary", BinaryType), StructField("timestamp", TimestampType), - StructField("date", DateType) - ) - ) + StructField("date", DateType))) val timestamp: Long = 1606179541282L @@ -1037,13 +994,10 @@ class APIInternalSuite extends TestData { true, Array(1.toByte, 2.toByte), new Timestamp(timestamp - 100), - new Date(timestamp - 100) - ) - ) + new Date(timestamp - 100))) } largeData.append( - Row(1025, null, null, null, null, null, null, null, null, null, null, null, null) - ) + Row(1025, null, null, null, null, null, null, null, null, null, null, null, null)) val df = session.createDataFrame(largeData, schema) checkExecuteAndGetQueryId(df) @@ -1054,8 +1008,7 @@ class APIInternalSuite extends TestData { val rows = session .sql( s"select QUERY_TAG from table(information_schema.QUERY_HISTORY_BY_SESSION())" + - s" where QUERY_TAG = '$uniqueQueryTag'" - ) + s" where QUERY_TAG = '$uniqueQueryTag'") .collect() // The statement parameter is applied for the last query only, // 
even if there are 3 queries and 1 post actions for large local relation, @@ -1069,8 +1022,7 @@ class APIInternalSuite extends TestData { val rows = session .sql( s"select QUERY_TAG from table(information_schema.QUERY_HISTORY_BY_SESSION())" + - s" where QUERY_TAG = '$uniqueQueryTag'" - ) + s" where QUERY_TAG = '$uniqueQueryTag'") .collect() // The statement parameter is applied for the last query only, // even if there are 3 queries and 1 post actions for multipleQueriesDF1 @@ -1078,8 +1030,7 @@ class APIInternalSuite extends TestData { // case 2: test int/boolean parameter multipleQueriesDF1.executeAndGetQueryId( - Map("STATEMENT_TIMEOUT_IN_SECONDS" -> 100, "USE_CACHED_RESULT" -> false) - ) + Map("STATEMENT_TIMEOUT_IN_SECONDS" -> 100, "USE_CACHED_RESULT" -> false)) } test("VariantTypes.getType") { @@ -1105,8 +1056,7 @@ class APIInternalSuite extends TestData { val cachedResult2 = cachedResult.cacheResult() assert( cachedResult.snowflakePlan.queries.last.sql == - cachedResult2.snowflakePlan.queries.last.sql - ) + cachedResult2.snowflakePlan.queries.last.sql) checkAnswer(cachedResult2, expected) } @@ -1120,8 +1070,7 @@ class APIInternalSuite extends TestData { assert( plan1.summarize == "Union(Filter(Project(Project(SnowflakeValues())))" + - ",Project(Project(Project(SnowflakeValues()))))" - ) + ",Project(Project(Project(SnowflakeValues()))))") } test("DataFrame toDF should not generate useless project") { @@ -1130,8 +1079,7 @@ class APIInternalSuite extends TestData { val result1 = df.toDF("b", "a") assert( result1.snowflakePlan.queries.last - .countString("SELECT \"A\" AS \"B\", \"B\" AS \"A\" FROM") == 1 - ) + .countString("SELECT \"A\" AS \"B\", \"B\" AS \"A\" FROM") == 1) val result2 = df.toDF("a", "B") assert(result2.eq(df)) assert(result2.snowflakePlan.queries.last.countString("\"A\" AS \"A\"") == 0) diff --git a/src/test/scala/com/snowflake/snowpark/DropTempObjectsSuite.scala b/src/test/scala/com/snowflake/snowpark/DropTempObjectsSuite.scala index 
9b15619f..f40ac2ef 100644 --- a/src/test/scala/com/snowflake/snowpark/DropTempObjectsSuite.scala +++ b/src/test/scala/com/snowflake/snowpark/DropTempObjectsSuite.scala @@ -12,8 +12,7 @@ class DropTempObjectsSuite extends SNTestBase { val randomSchema: String = randomName() val tmpStageName: String = randomStageName() private val userSchema: StructType = StructType( - Seq(StructField("a", IntegerType), StructField("b", StringType), StructField("c", DoubleType)) - ) + Seq(StructField("a", IntegerType), StructField("b", StringType), StructField("c", DoubleType))) override def beforeAll(): Unit = { super.beforeAll() @@ -60,8 +59,7 @@ class DropTempObjectsSuite extends SNTestBase { }, ParameterUtils.SnowparkUseScopedTempObjects, - "true" - ) + "true") } test("test session dropAllTempObjects with scoped temp object turned off") { @@ -91,10 +89,7 @@ class DropTempObjectsSuite extends SNTestBase { TempObjectType.Stage, TempObjectType.Table, TempObjectType.FileFormat, - TempObjectType.Function - ) - ) - ) + TempObjectType.Function))) dropMap.keys.foreach(k => { // Make sure name is fully qualified assert(k.split("\\.")(2).startsWith("SNOWPARK_TEMP_")) @@ -103,28 +98,24 @@ class DropTempObjectsSuite extends SNTestBase { session.dropAllTempObjects() }, ParameterUtils.SnowparkUseScopedTempObjects, - "false" - ) + "false") } test("Test recordTempObjectIfNecessary") { session.recordTempObjectIfNecessary( TempObjectType.Table, "db.schema.tempName1", - TempType.Temporary - ) + TempType.Temporary) assertTrue(session.getTempObjectMap.contains("db.schema.tempName1")) session.recordTempObjectIfNecessary( TempObjectType.Table, "db.schema.tempName2", - TempType.ScopedTemporary - ) + TempType.ScopedTemporary) assertFalse(session.getTempObjectMap.contains("db.schema.tempName2")) session.recordTempObjectIfNecessary( TempObjectType.Table, "db.schema.tempName3", - TempType.Permanent - ) + TempType.Permanent) assertFalse(session.getTempObjectMap.contains("db.schema.tempName3")) } } diff --git 
a/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala b/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala index 50a10fa7..b2dedb44 100644 --- a/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala +++ b/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala @@ -14,11 +14,8 @@ class ErrorMessageSuite extends FunSuite { val ex = ErrorMessage.INTERNAL_TEST_MESSAGE("my message: '%d $'") assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0010"))) assert( - ex.message.startsWith( - "Error Code: 0010, Error message: " + - "internal test message: my message: '%d $'" - ) - ) + ex.message.startsWith("Error Code: 0010, Error message: " + + "internal test message: my message: '%d $'")) } test("DF_CANNOT_DROP_COLUMN_NAME") { @@ -28,31 +25,23 @@ class ErrorMessageSuite extends FunSuite { ex.message.startsWith( "Error Code: 0100, Error message: " + "Unable to drop the column col. You must specify " + - "the column by name (e.g. df.drop(col(\"a\")))." - ) - ) + "the column by name (e.g. df.drop(col(\"a\"))).")) } test("DF_SORT_NEED_AT_LEAST_ONE_EXPR") { val ex = ErrorMessage.DF_SORT_NEED_AT_LEAST_ONE_EXPR() assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0101"))) assert( - ex.message.startsWith( - "Error Code: 0101, Error message: " + - "For sort(), you must specify at least one sort expression." 
- ) - ) + ex.message.startsWith("Error Code: 0101, Error message: " + + "For sort(), you must specify at least one sort expression.")) } test("DF_CANNOT_DROP_ALL_COLUMNS") { val ex = ErrorMessage.DF_CANNOT_DROP_ALL_COLUMNS() assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0102"))) assert( - ex.message.startsWith( - "Error Code: 0102, Error message: " + - s"Cannot drop all columns" - ) - ) + ex.message.startsWith("Error Code: 0102, Error message: " + + s"Cannot drop all columns")) } test("DF_CANNOT_RESOLVE_COLUMN_NAME_AMONG") { @@ -62,9 +51,7 @@ class ErrorMessageSuite extends FunSuite { ex.message.startsWith( "Error Code: 0103, Error message: " + "Cannot combine the DataFrames by column names. " + - """The column "c1" is not a column in the other DataFrame (a, b, c).""" - ) - ) + """The column "c1" is not a column in the other DataFrame (a, b, c).""")) } test("DF_SELF_JOIN_NOT_SUPPORTED") { @@ -75,31 +62,23 @@ class ErrorMessageSuite extends FunSuite { "Error Code: 0104, Error message: " + "You cannot join a DataFrame with itself because the column references cannot " + "be resolved correctly. Instead, call clone() to create a copy of the DataFrame," + - " and join the DataFrame with this copy." - ) - ) + " and join the DataFrame with this copy.")) } test("DF_RANDOM_SPLIT_WEIGHT_INVALID") { val ex = ErrorMessage.DF_RANDOM_SPLIT_WEIGHT_INVALID() assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0105"))) assert( - ex.message.startsWith( - "Error Code: 0105, Error message: " + - "The specified weights for randomSplit() must not be negative numbers." 
- ) - ) + ex.message.startsWith("Error Code: 0105, Error message: " + + "The specified weights for randomSplit() must not be negative numbers.")) } test("DF_RANDOM_SPLIT_WEIGHT_ARRAY_EMPTY") { val ex = ErrorMessage.DF_RANDOM_SPLIT_WEIGHT_ARRAY_EMPTY() assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0106"))) assert( - ex.message.startsWith( - "Error Code: 0106, Error message: " + - "You cannot pass an empty array of weights to randomSplit()." - ) - ) + ex.message.startsWith("Error Code: 0106, Error message: " + + "You cannot pass an empty array of weights to randomSplit().")) } test("DF_FLATTEN_UNSUPPORTED_INPUT_MODE") { @@ -109,31 +88,23 @@ class ErrorMessageSuite extends FunSuite { ex.message.startsWith( "Error Code: 0107, Error message: " + "Unsupported input mode String. For the mode parameter in flatten(), " + - "you must specify OBJECT, ARRAY, or BOTH." - ) - ) + "you must specify OBJECT, ARRAY, or BOTH.")) } test("DF_CANNOT_RESOLVE_COLUMN_NAME") { val ex = ErrorMessage.DF_CANNOT_RESOLVE_COLUMN_NAME("col1", Seq("c1", "c3")) assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0108"))) assert( - ex.message.startsWith( - "Error Code: 0108, Error message: " + - "The DataFrame does not contain the column named 'col1' and the valid names are c1, c3." - ) - ) + ex.message.startsWith("Error Code: 0108, Error message: " + + "The DataFrame does not contain the column named 'col1' and the valid names are c1, c3.")) } test("DF_MUST_PROVIDE_SCHEMA_FOR_READING_FILE") { val ex = ErrorMessage.DF_MUST_PROVIDE_SCHEMA_FOR_READING_FILE() assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0109"))) assert( - ex.message.startsWith( - "Error Code: 0109, Error message: " + - "You must call DataFrameReader.schema() and specify the schema for the file." 
- ) - ) + ex.message.startsWith("Error Code: 0109, Error message: " + + "You must call DataFrameReader.schema() and specify the schema for the file.")) } test("DF_CROSS_TAB_COUNT_TOO_LARGE") { @@ -143,9 +114,7 @@ class ErrorMessageSuite extends FunSuite { ex.message.startsWith( "Error Code: 0110, Error message: " + "The number of distinct values in the second input column (1) " + - "exceeds the maximum number of distinct values allowed (2)." - ) - ) + "exceeds the maximum number of distinct values allowed (2).")) } test("DF_DATAFRAME_IS_NOT_QUALIFIED_FOR_SCALAR_QUERY") { @@ -155,9 +124,7 @@ class ErrorMessageSuite extends FunSuite { ex.message.startsWith( "Error Code: 0111, Error message: " + "The DataFrame passed in to this function must have only one output column. " + - "This DataFrame has 2 output columns: c1, c2" - ) - ) + "This DataFrame has 2 output columns: c1, c2")) } test("DF_PIVOT_ONLY_SUPPORT_ONE_AGG_EXPR") { @@ -167,86 +134,63 @@ class ErrorMessageSuite extends FunSuite { ex.message.startsWith( "Error Code: 0112, Error message: " + "You can apply only one aggregate expression to a RelationalGroupedDataFrame " + - "returned by the pivot() method." - ) - ) + "returned by the pivot() method.")) } test("DF_FUNCTION_ARGS_CANNOT_BE_EMPTY") { val ex = ErrorMessage.DF_FUNCTION_ARGS_CANNOT_BE_EMPTY("myFunc") assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0113"))) assert( - ex.message.startsWith( - "Error Code: 0113, Error message: " + - "You must pass a Seq of one or more Columns to function: myFunc" - ) - ) + ex.message.startsWith("Error Code: 0113, Error message: " + + "You must pass a Seq of one or more Columns to function: myFunc")) } test("DF_WINDOW_BOUNDARY_START_INVALID") { val ex = ErrorMessage.DF_WINDOW_BOUNDARY_START_INVALID(123) assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0114"))) assert( - ex.message.startsWith( - "Error Code: 0114, Error message: " + - "The starting point for the window frame is not a valid integer: 123." 
- ) - ) + ex.message.startsWith("Error Code: 0114, Error message: " + + "The starting point for the window frame is not a valid integer: 123.")) } test("DF_WINDOW_BOUNDARY_END_INVALID") { val ex = ErrorMessage.DF_WINDOW_BOUNDARY_END_INVALID(123) assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0115"))) assert( - ex.message.startsWith( - "Error Code: 0115, Error message: " + - "The ending point for the window frame is not a valid integer: 123." - ) - ) + ex.message.startsWith("Error Code: 0115, Error message: " + + "The ending point for the window frame is not a valid integer: 123.")) } test("DF_JOIN_INVALID_JOIN_TYPE") { val ex = ErrorMessage.DF_JOIN_INVALID_JOIN_TYPE("inner", "left, right") assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0116"))) assert( - ex.message.startsWith( - "Error Code: 0116, Error message: " + - "Unsupported join type 'inner'. Supported join types include: left, right." - ) - ) + ex.message.startsWith("Error Code: 0116, Error message: " + + "Unsupported join type 'inner'. Supported join types include: left, right.")) } test("DF_JOIN_INVALID_NATURAL_JOIN_TYPE") { val ex = ErrorMessage.DF_JOIN_INVALID_NATURAL_JOIN_TYPE("leftanti") assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0117"))) assert( - ex.message.startsWith( - "Error Code: 0117, Error message: " + - "Unsupported natural join type 'leftanti'." - ) - ) + ex.message.startsWith("Error Code: 0117, Error message: " + + "Unsupported natural join type 'leftanti'.")) } test("DF_JOIN_INVALID_USING_JOIN_TYPE") { val ex = ErrorMessage.DF_JOIN_INVALID_USING_JOIN_TYPE("leftanti") assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0118"))) assert( - ex.message.startsWith( - "Error Code: 0118, Error message: " + - "Unsupported using join type 'leftanti'." 
- ) - ) + ex.message.startsWith("Error Code: 0118, Error message: " + + "Unsupported using join type 'leftanti'.")) } test("DF_RANGE_STEP_CANNOT_BE_ZERO") { val ex = ErrorMessage.DF_RANGE_STEP_CANNOT_BE_ZERO() assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0119"))) assert( - ex.message.startsWith( - "Error Code: 0119, Error message: " + - "The step for range() cannot be 0." - ) - ) + ex.message.startsWith("Error Code: 0119, Error message: " + + "The step for range() cannot be 0.")) } test("DF_CANNOT_RENAME_COLUMN_BECAUSE_NOT_EXIST") { @@ -256,9 +200,7 @@ class ErrorMessageSuite extends FunSuite { ex.message.startsWith( "Error Code: 0120, Error message: " + "Unable to rename the column oldName as newName because" + - " this DataFrame doesn't have a column named oldName." - ) - ) + " this DataFrame doesn't have a column named oldName.")) } test("DF_CANNOT_RENAME_COLUMN_BECAUSE_MULTIPLE_EXIST") { @@ -268,9 +210,7 @@ class ErrorMessageSuite extends FunSuite { ex.message.startsWith( "Error Code: 0121, Error message: " + "Unable to rename the column oldName as newName because" + - " this DataFrame has 3 columns named oldName." - ) - ) + " this DataFrame has 3 columns named oldName.")) } test("DF_COPY_INTO_CANNOT_CREATE_TABLE") { @@ -280,53 +220,39 @@ class ErrorMessageSuite extends FunSuite { ex.message.startsWith( "Error Code: 0122, Error message: " + "Cannot create the target table table_123 because Snowpark cannot determine" + - " the column names to use. You should create the table before calling copyInto()." - ) - ) + " the column names to use. You should create the table before calling copyInto().")) } test("DF_WITH_COLUMNS_INPUT_NAMES_NOT_MATCH_VALUES") { val ex = ErrorMessage.DF_WITH_COLUMNS_INPUT_NAMES_NOT_MATCH_VALUES(10, 5) assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0123"))) assert( - ex.message.startsWith( - "Error Code: 0123, Error message: " + - "The number of column names (10) does not match the number of values (5)." 
- ) - ) + ex.message.startsWith("Error Code: 0123, Error message: " + + "The number of column names (10) does not match the number of values (5).")) } test("DF_WITH_COLUMNS_INPUT_NAMES_CONTAINS_DUPLICATES") { val ex = ErrorMessage.DF_WITH_COLUMNS_INPUT_NAMES_CONTAINS_DUPLICATES assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0124"))) assert( - ex.message.startsWith( - "Error Code: 0124, Error message: " + - "The same column name is used multiple times in the colNames parameter." - ) - ) + ex.message.startsWith("Error Code: 0124, Error message: " + + "The same column name is used multiple times in the colNames parameter.")) } test("DF_COLUMN_LIST_OF_GENERATOR_CANNOT_BE_EMPTY") { val ex = ErrorMessage.DF_COLUMN_LIST_OF_GENERATOR_CANNOT_BE_EMPTY() assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0125"))) assert( - ex.message.startsWith( - "Error Code: 0125, Error message: " + - "The column list of generator function can not be empty." - ) - ) + ex.message.startsWith("Error Code: 0125, Error message: " + + "The column list of generator function can not be empty.")) } test("DF_WRITER_INVALID_OPTION_NAME") { val ex = ErrorMessage.DF_WRITER_INVALID_OPTION_NAME("myOption", "table") assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0126"))) assert( - ex.message.startsWith( - "Error Code: 0126, Error message: " + - "DataFrameWriter doesn't support option 'myOption' when writing to a table." - ) - ) + ex.message.startsWith("Error Code: 0126, Error message: " + + "DataFrameWriter doesn't support option 'myOption' when writing to a table.")) } test("DF_WRITER_INVALID_OPTION_VALUE") { @@ -336,9 +262,7 @@ class ErrorMessageSuite extends FunSuite { ex.message.startsWith( "Error Code: 0127, Error message: " + "DataFrameWriter doesn't support to set option 'myOption' as 'myValue'" + - " when writing to a table." 
- ) - ) + " when writing to a table.")) } test("DF_WRITER_INVALID_OPTION_NAME_FOR_MODE") { @@ -347,27 +271,19 @@ class ErrorMessageSuite extends FunSuite { "myOption", "myValue", "Overwrite", - "table" - ) + "table") assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0128"))) - assert( - ex.message.startsWith( - "Error Code: 0128, Error message: " + - "DataFrameWriter doesn't support to set option 'myOption' as 'myValue' in 'Overwrite' mode" + - " when writing to a table." - ) - ) + assert(ex.message.startsWith("Error Code: 0128, Error message: " + + "DataFrameWriter doesn't support to set option 'myOption' as 'myValue' in 'Overwrite' mode" + + " when writing to a table.")) } test("DF_WRITER_INVALID_MODE") { val ex = ErrorMessage.DF_WRITER_INVALID_MODE("Append", "file") assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0129"))) assert( - ex.message.startsWith( - "Error Code: 0129, Error message: " + - "DataFrameWriter doesn't support mode 'Append' when writing to a file." 
- ) - ) + ex.message.startsWith("Error Code: 0129, Error message: " + + "DataFrameWriter doesn't support mode 'Append' when writing to a file.")) } test("DF_JOIN_WITH_WRONG_ARGUMENT") { @@ -377,53 +293,39 @@ class ErrorMessageSuite extends FunSuite { ex.message.startsWith( "Error Code: 0130, Error message: " + "Unsupported join operations, Dataframes can join with other Dataframes" + - " or TableFunctions only" - ) - ) + " or TableFunctions only")) } test("DF_MORE_THAN_ONE_TF_IN_SELECT") { val ex = ErrorMessage.DF_MORE_THAN_ONE_TF_IN_SELECT() assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0131"))) assert( - ex.message.startsWith( - "Error Code: 0131, Error message: " + - "At most one table function can be called inside select() function" - ) - ) + ex.message.startsWith("Error Code: 0131, Error message: " + + "At most one table function can be called inside select() function")) } test("DF_ALIAS_DUPLICATES") { val ex = ErrorMessage.DF_ALIAS_DUPLICATES(Set("a", "b")) assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0132"))) assert( - ex.message.startsWith( - "Error Code: 0132, Error message: " + - "Duplicated dataframe alias defined: a, b" - ) - ) + ex.message.startsWith("Error Code: 0132, Error message: " + + "Duplicated dataframe alias defined: a, b")) } test("UDF_INCORRECT_ARGS_NUMBER") { val ex = ErrorMessage.UDF_INCORRECT_ARGS_NUMBER(1, 2) assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0200"))) assert( - ex.message.startsWith( - "Error Code: 0200, Error message: " + - "Incorrect number of arguments passed to the UDF: Expected: 1, Found: 2" - ) - ) + ex.message.startsWith("Error Code: 0200, Error message: " + + "Incorrect number of arguments passed to the UDF: Expected: 1, Found: 2")) } test("UDF_FOUND_UNREGISTERED_UDF") { val ex = ErrorMessage.UDF_FOUND_UNREGISTERED_UDF() assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0201"))) assert( - ex.message.startsWith( - "Error Code: 0201, Error message: " + - "Attempted to 
call an unregistered UDF. You must register the UDF before calling it." - ) - ) + ex.message.startsWith("Error Code: 0201, Error message: " + + "Attempted to call an unregistered UDF. You must register the UDF before calling it.")) } test("UDF_CANNOT_DETECT_UDF_FUNCION_CLASS") { @@ -434,9 +336,7 @@ class ErrorMessageSuite extends FunSuite { "Error Code: 0202, Error message: " + "Unable to detect the location of the enclosing class of the UDF. " + "Call session.addDependency, and pass in the path to the directory or JAR file " + - "containing the compiled class file." - ) - ) + "containing the compiled class file.")) } test("UDF_NEED_SCALA_2_12_OR_LAMBDAFYMETHOD") { @@ -446,20 +346,15 @@ class ErrorMessageSuite extends FunSuite { ex.message.startsWith( "Error Code: 0203, Error message: " + "Unable to clean the closure. You must use a supported version of Scala (2.12+) or " + - "specify the Scala compiler flag -Dlambdafymethod." - ) - ) + "specify the Scala compiler flag -Dlambdafymethod.")) } test("UDF_RETURN_STATEMENT_IN_CLOSURE_EXCEPTION") { val ex = ErrorMessage.UDF_RETURN_STATEMENT_IN_CLOSURE_EXCEPTION() assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0204"))) assert( - ex.message.startsWith( - "Error Code: 0204, Error message: " + - "You cannot include a return statement in a closure." - ) - ) + ex.message.startsWith("Error Code: 0204, Error message: " + + "You cannot include a return statement in a closure.")) } test("UDF_CANNOT_FIND_JAVA_COMPILER") { @@ -469,20 +364,15 @@ class ErrorMessageSuite extends FunSuite { ex.message.startsWith( "Error Code: 0205, Error message: " + "Cannot find the JDK. For your development environment, set the JAVA_HOME environment " + - "variable to the directory where you installed a supported version of the JDK." 
- ) - ) + "variable to the directory where you installed a supported version of the JDK.")) } test("UDF_ERROR_IN_COMPILING_CODE") { val ex = ErrorMessage.UDF_ERROR_IN_COMPILING_CODE("'v1' is not defined") assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0206"))) assert( - ex.message.startsWith( - "Error Code: 0206, Error message: " + - "Error compiling your UDF code: 'v1' is not defined" - ) - ) + ex.message.startsWith("Error Code: 0206, Error message: " + + "Error compiling your UDF code: 'v1' is not defined")) } test("UDF_NO_DEFAULT_SESSION_FOUND") { @@ -492,9 +382,7 @@ class ErrorMessageSuite extends FunSuite { ex.message.startsWith( "Error Code: 0207, Error message: " + "No default Session found. Use .udf.registerTemporary()" + - " to explicitly refer to a session." - ) - ) + " to explicitly refer to a session.")) } test("UDF_INVALID_UDTF_COLUMN_NAME") { @@ -504,9 +392,7 @@ class ErrorMessageSuite extends FunSuite { ex.message.startsWith( "Error Code: 0208, Error message: " + "You cannot use 'invalid name' as an UDTF output schema name " + - "which needs to be a valid Java identifier." - ) - ) + "which needs to be a valid Java identifier.")) } test("UDF_CANNOT_INFER_MAP_TYPES") { @@ -517,9 +403,7 @@ class ErrorMessageSuite extends FunSuite { "Error Code: 0209, Error message: " + "Cannot determine the input types because the process() method passes in" + " Map arguments. In your JavaUDTF class, implement the inputSchema() method to" + - " describe the input types." - ) - ) + " describe the input types.")) } test("UDF_CANNOT_INFER_MULTIPLE_PROCESS") { @@ -530,20 +414,15 @@ class ErrorMessageSuite extends FunSuite { "Error Code: 0210, Error message: " + "Cannot determine the input types because the process() method has multiple signatures" + " with 3 arguments. In your JavaUDTF class, implement the inputSchema() method to" + - " describe the input types." 
- ) - ) + " describe the input types.")) } test("UDF_INCORRECT_SPROC_ARGS_NUMBER") { val ex = ErrorMessage.UDF_INCORRECT_SPROC_ARGS_NUMBER(1, 2) assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0211"))) assert( - ex.message.startsWith( - "Error Code: 0211, Error message: " + - "Incorrect number of arguments passed to the SProc: Expected: 1, Found: 2" - ) - ) + ex.message.startsWith("Error Code: 0211, Error message: " + + "Incorrect number of arguments passed to the SProc: Expected: 1, Found: 2")) } test("UDF_CANNOT_ACCEPT_MANY_DF_COLS") { @@ -553,9 +432,7 @@ class ErrorMessageSuite extends FunSuite { ex.message.startsWith( "Error Code: 0212, Error message: " + "Session.tableFunction does not support columns from more than one dataframe as input." + - " Join these dataframes before using the function" - ) - ) + " Join these dataframes before using the function")) } test("UDF_UNEXPECTED_COLUMN_ORDER") { @@ -565,9 +442,7 @@ class ErrorMessageSuite extends FunSuite { ex.message.startsWith( "Error Code: 0213, Error message: " + "Dataframe resulting from table function has an unexpected column order." + - " Source DataFrame columns did not come first." - ) - ) + " Source DataFrame columns did not come first.")) } test("PLAN_LAST_QUERY_RETURN_RESULTSET") { @@ -577,20 +452,15 @@ class ErrorMessageSuite extends FunSuite { ex.message.startsWith( "Error Code: 0300, Error message: " + "Internal error: the execution for the last query in the snowflake plan " + - "doesn't return a ResultSet." 
- ) - ) + "doesn't return a ResultSet.")) } test("PLAN_ANALYZER_INVALID_NAME") { val ex = ErrorMessage.PLAN_ANALYZER_INVALID_IDENTIFIER("wrong_identifier") assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0301"))) assert( - ex.message.startsWith( - "Error Code: 0301, Error message: " + - "Invalid identifier wrong_identifier" - ) - ) + ex.message.startsWith("Error Code: 0301, Error message: " + + "Invalid identifier wrong_identifier")) } test("PLAN_ANALYZER_UNSUPPORTED_VIEW_TYPE") { @@ -600,20 +470,15 @@ class ErrorMessageSuite extends FunSuite { ex.message.startsWith( "Error Code: 0302, Error message: " + "Internal Error: Only PersistedView and LocalTempView are supported. " + - "view type: wrong view type" - ) - ) + "view type: wrong view type")) } test("PLAN_SAMPLING_NEED_ONE_PARAMETER") { val ex = ErrorMessage.PLAN_SAMPLING_NEED_ONE_PARAMETER() assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0303"))) assert( - ex.message.startsWith( - "Error Code: 0303, Error message: " + - "You must specify either the fraction of rows or the number of rows to sample." - ) - ) + ex.message.startsWith("Error Code: 0303, Error message: " + + "You must specify either the fraction of rows or the number of rows to sample.")) } test("PLAN_JOIN_NEED_USING_CLAUSE_OR_JOIN_CONDITION") { @@ -623,42 +488,31 @@ class ErrorMessageSuite extends FunSuite { ex.message.startsWith( "Error Code: 0304, Error message: " + "For the join, you must specify either the conditions for the join or " + - "the list of columns to use for the join (not both)." 
- ) - ) + "the list of columns to use for the join (not both).")) } test("PLAN_LEFT_SEMI_JOIN_NOT_SUPPORT_USING_CLAUSE") { val ex = ErrorMessage.PLAN_LEFT_SEMI_JOIN_NOT_SUPPORT_USING_CLAUSE() assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0305"))) assert( - ex.message.startsWith( - "Error Code: 0305, Error message: " + - "Internal error: Unexpected Using clause in left semi join" - ) - ) + ex.message.startsWith("Error Code: 0305, Error message: " + + "Internal error: Unexpected Using clause in left semi join")) } test("PLAN_LEFT_ANTI_JOIN_NOT_SUPPORT_USING_CLAUSE") { val ex = ErrorMessage.PLAN_LEFT_ANTI_JOIN_NOT_SUPPORT_USING_CLAUSE() assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0306"))) assert( - ex.message.startsWith( - "Error Code: 0306, Error message: " + - "Internal error: Unexpected Using clause in left anti join" - ) - ) + ex.message.startsWith("Error Code: 0306, Error message: " + + "Internal error: Unexpected Using clause in left anti join")) } test("PLAN_UNSUPPORTED_FILE_OPERATION_TYPE") { val ex = ErrorMessage.PLAN_UNSUPPORTED_FILE_OPERATION_TYPE() assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0307"))) assert( - ex.message.startsWith( - "Error Code: 0307, Error message: " + - "Internal error: Unsupported file operation type" - ) - ) + ex.message.startsWith("Error Code: 0307, Error message: " + + "Internal error: Unsupported file operation type")) } test("PLAN_JDBC_REPORT_UNEXPECTED_ALIAS") { @@ -668,20 +522,15 @@ class ErrorMessageSuite extends FunSuite { ex.message.startsWith( "Error Code: 0308, Error message: " + "You can only define aliases for the root Columns in a DataFrame returned by " + - "select() and agg(). You cannot use aliases for Columns in expressions." - ) - ) + "select() and agg(). 
You cannot use aliases for Columns in expressions.")) } test("PLAN_JDBC_REPORT_INVALID_ID") { val ex = ErrorMessage.PLAN_JDBC_REPORT_INVALID_ID("id1") assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0309"))) assert( - ex.message.startsWith( - "Error Code: 0309, Error message: " + - """The column specified in df("id1") is not present in the output of the DataFrame.""" - ) - ) + ex.message.startsWith("Error Code: 0309, Error message: " + + """The column specified in df("id1") is not present in the output of the DataFrame.""")) } test("PLAN_JDBC_REPORT_JOIN_AMBIGUOUS") { @@ -694,9 +543,7 @@ class ErrorMessageSuite extends FunSuite { "The column is present in both DataFrames used in the join. " + "To identify the DataFrame that you want to use in the reference, " + """use the syntax ("b") in join conditions and in select() calls """ + - "on the result of the join." - ) - ) + "on the result of the join.")) } test("PLAN_COPY_DONT_SUPPORT_SKIP_LOADED_FILES") { @@ -707,31 +554,23 @@ class ErrorMessageSuite extends FunSuite { "Error Code: 0311, Error message: " + "The COPY option 'FORCE = false' is not supported by the Snowpark library. " + "The Snowflake library loads all files, even if the files have been loaded previously " + - "and have not changed since they were loaded." 
- ) - ) + "and have not changed since they were loaded.")) } test("PLAN_CANNOT_CREATE_LITERAL") { val ex = ErrorMessage.PLAN_CANNOT_CREATE_LITERAL("MyClass", "myValue") assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0312"))) assert( - ex.message.startsWith( - "Error Code: 0312, Error message: " + - "Cannot create a Literal for MyClass(myValue)" - ) - ) + ex.message.startsWith("Error Code: 0312, Error message: " + + "Cannot create a Literal for MyClass(myValue)")) } test("PLAN_UNSUPPORTED_FILE_FORMAT_TYPE") { val ex = ErrorMessage.PLAN_UNSUPPORTED_FILE_FORMAT_TYPE("unknown_type") assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0313"))) assert( - ex.message.startsWith( - "Error Code: 0313, Error message: " + - "Internal error: unsupported file format type: 'unknown_type'." - ) - ) + ex.message.startsWith("Error Code: 0313, Error message: " + + "Internal error: unsupported file format type: 'unknown_type'.")) } test("PLAN_IN_EXPRESSION_UNSUPPORTED_VALUE") { @@ -742,20 +581,15 @@ class ErrorMessageSuite extends FunSuite { "Error Code: 0314, Error message: " + "'column_A' is not supported for the values parameter of the function in()." + " You must either specify a sequence of literals or a DataFrame that" + - " represents a subquery." - ) - ) + " represents a subquery.")) } test("PLAN_IN_EXPRESSION_INVALID_VALUE_COUNT") { val ex = ErrorMessage.PLAN_IN_EXPRESSION_INVALID_VALUE_COUNT(1, 2) assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0315"))) assert( - ex.message.startsWith( - "Error Code: 0315, Error message: " + - "For the in() function, the number of values 1 does not match the number of columns 2." 
- ) - ) + ex.message.startsWith("Error Code: 0315, Error message: " + + "For the in() function, the number of values 1 does not match the number of columns 2.")) } test("PLAN_COPY_INVALID_COLUMN_NAME_SIZE") { @@ -765,20 +599,15 @@ class ErrorMessageSuite extends FunSuite { ex.message.startsWith( "Error Code: 0316, Error message: " + "Number of column names provided to copy into does not match the number of " + - "transformations provided. Number of column names: 1, number of transformations: 2." - ) - ) + "transformations provided. Number of column names: 1, number of transformations: 2.")) } test("PLAN_CANNOT_EXECUTE_IN_ASYNC_MODE") { val ex = ErrorMessage.PLAN_CANNOT_EXECUTE_IN_ASYNC_MODE("Plan(q1, q2)") assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0317"))) assert( - ex.message.startsWith( - "Error Code: 0317, Error message: " + - "Cannot execute the following plan asynchronously:\nPlan(q1, q2)" - ) - ) + ex.message.startsWith("Error Code: 0317, Error message: " + + "Cannot execute the following plan asynchronously:\nPlan(q1, q2)")) } test("PLAN_QUERY_IS_STILL_RUNNING") { @@ -790,20 +619,15 @@ class ErrorMessageSuite extends FunSuite { "The query with the ID qid_123 is still running and has the current status RUNNING." + " The function call has been running for 100 seconds, which exceeds the maximum number" + " of seconds to wait for the results. Use the `maxWaitTimeInSeconds` argument" + - " to increase the number of seconds to wait." - ) - ) + " to increase the number of seconds to wait.")) } test("PLAN_CANNOT_SUPPORT_TYPE_FOR_ASYNC_JOB") { val ex = ErrorMessage.PLAN_CANNOT_SUPPORT_TYPE_FOR_ASYNC_JOB("MyType") assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0319"))) assert( - ex.message.startsWith( - "Error Code: 0319, Error message: " + - "Internal Error: Unsupported type 'MyType' for TypedAsyncJob." 
- ) - ) + ex.message.startsWith("Error Code: 0319, Error message: " + + "Internal Error: Unsupported type 'MyType' for TypedAsyncJob.")) } test("PLAN_CANNOT_GET_ASYNC_JOB_RESULT") { @@ -813,31 +637,23 @@ class ErrorMessageSuite extends FunSuite { ex.message.startsWith( "Error Code: 0320, Error message: " + "Internal Error: Cannot retrieve the value for the type 'MyType'" + - " in the function 'myFunc'." - ) - ) + " in the function 'myFunc'.")) } test("PLAN_MERGE_RETURN_WRONG_ROWS") { val ex = ErrorMessage.PLAN_MERGE_RETURN_WRONG_ROWS(1, 2) assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0321"))) assert( - ex.message.startsWith( - "Error Code: 0321, Error message: " + - "Internal error: Merge statement should return 1 row but returned 2 rows" - ) - ) + ex.message.startsWith("Error Code: 0321, Error message: " + + "Internal error: Merge statement should return 1 row but returned 2 rows")) } test("MISC_CANNOT_CAST_VALUE") { val ex = ErrorMessage.MISC_CANNOT_CAST_VALUE("MyClass", "value123", "Int") assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0400"))) assert( - ex.message.startsWith( - "Error Code: 0400, Error message: " + - "Cannot cast MyClass(value123) to Int." - ) - ) + ex.message.startsWith("Error Code: 0400, Error message: " + + "Cannot cast MyClass(value123) to Int.")) } test("MISC_CANNOT_FIND_CURRENT_DB_OR_SCHEMA") { @@ -848,108 +664,79 @@ class ErrorMessageSuite extends FunSuite { "Error Code: 0401, Error message: " + "The DB is not set for the current session. To set this, either run " + "session.sql(\"USE DB\").collect() or set the DB connection property in " + - "the Map or properties file that you specify when creating a session." 
- ) - ) + "the Map or properties file that you specify when creating a session.")) } test("MISC_QUERY_IS_CANCELLED") { val ex = ErrorMessage.MISC_QUERY_IS_CANCELLED() assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0402"))) assert( - ex.message.startsWith( - "Error Code: 0402, Error message: " + - "The query has been cancelled by the user." - ) - ) + ex.message.startsWith("Error Code: 0402, Error message: " + + "The query has been cancelled by the user.")) } test("MISC_INVALID_CLIENT_VERSION") { val ex = ErrorMessage.MISC_INVALID_CLIENT_VERSION("0.6.x") assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0403"))) assert( - ex.message.startsWith( - "Error Code: 0403, Error message: " + - "Invalid client version string 0.6.x" - ) - ) + ex.message.startsWith("Error Code: 0403, Error message: " + + "Invalid client version string 0.6.x")) } test("MISC_INVALID_CLOSURE_CLEANER_PARAMETER") { val ex = ErrorMessage.MISC_INVALID_CLOSURE_CLEANER_PARAMETER("my_parameter") assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0404"))) assert( - ex.message.startsWith( - "Error Code: 0404, Error message: " + - "The parameter my_parameter must be 'always', 'never', or 'repl_only'." 
- ) - ) + ex.message.startsWith("Error Code: 0404, Error message: " + + "The parameter my_parameter must be 'always', 'never', or 'repl_only'.")) } test("MISC_INVALID_CONNECTION_STRING") { val ex = ErrorMessage.MISC_INVALID_CONNECTION_STRING("invalid_connection_string") assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0405"))) assert( - ex.message.startsWith( - "Error Code: 0405, Error message: " + - "Invalid connection string invalid_connection_string" - ) - ) + ex.message.startsWith("Error Code: 0405, Error message: " + + "Invalid connection string invalid_connection_string")) } test("MISC_MULTIPLE_VALUES_RETURNED_FOR_PARAMETER") { val ex = ErrorMessage.MISC_MULTIPLE_VALUES_RETURNED_FOR_PARAMETER("myParameter") assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0406"))) assert( - ex.message.startsWith( - "Error Code: 0406, Error message: " + - "The server returned multiple values for the parameter myParameter." - ) - ) + ex.message.startsWith("Error Code: 0406, Error message: " + + "The server returned multiple values for the parameter myParameter.")) } test("MISC_NO_VALUES_RETURNED_FOR_PARAMETER") { val ex = ErrorMessage.MISC_NO_VALUES_RETURNED_FOR_PARAMETER("myParameter") assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0407"))) assert( - ex.message.startsWith( - "Error Code: 0407, Error message: " + - "The server returned no value for the parameter myParameter." - ) - ) + ex.message.startsWith("Error Code: 0407, Error message: " + + "The server returned no value for the parameter myParameter.")) } test("MISC_SESSION_EXPIRED") { val ex = ErrorMessage.MISC_SESSION_EXPIRED("session expired!") assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0408"))) assert( - ex.message.startsWith( - "Error Code: 0408, Error message: " + - "Your Snowpark session has expired. You must recreate your session.\nsession expired!" - ) - ) + ex.message.startsWith("Error Code: 0408, Error message: " + + "Your Snowpark session has expired. 
You must recreate your session.\nsession expired!")) } test("MISC_NESTED_OPTION_TYPE_IS_NOT_SUPPORTED") { val ex = ErrorMessage.MISC_NESTED_OPTION_TYPE_IS_NOT_SUPPORTED() assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0409"))) assert( - ex.message.startsWith( - "Error Code: 0409, Error message: " + - "You cannot use a nested Option type (e.g. Option[Option[Int]])." - ) - ) + ex.message.startsWith("Error Code: 0409, Error message: " + + "You cannot use a nested Option type (e.g. Option[Option[Int]]).")) } test("MISC_CANNOT_INFER_SCHEMA_FROM_TYPE") { val ex = ErrorMessage.MISC_CANNOT_INFER_SCHEMA_FROM_TYPE("MyClass") assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0410"))) assert( - ex.message.startsWith( - "Error Code: 0410, Error message: " + - "Could not infer schema from data of type: MyClass" - ) - ) + ex.message.startsWith("Error Code: 0410, Error message: " + + "Could not infer schema from data of type: MyClass")) } test("MISC_SCALA_VERSION_NOT_SUPPORTED") { @@ -959,64 +746,47 @@ class ErrorMessageSuite extends FunSuite { ex.message.startsWith( "Error Code: 0411, Error message: " + "Scala version 2.12.6 detected. Snowpark only supports Scala version 2.12 with " + - "the minor version 2.12.9 and higher." - ) - ) + "the minor version 2.12.9 and higher.")) } test("MISC_INVALID_OBJECT_NAME") { val ex = ErrorMessage.MISC_INVALID_OBJECT_NAME("objName") assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0412"))) assert( - ex.message.startsWith( - "Error Code: 0412, Error message: " + - "The object name 'objName' is invalid." - ) - ) + ex.message.startsWith("Error Code: 0412, Error message: " + + "The object name 'objName' is invalid.")) } test("MISC_SP_ACTIVE_SESSION_RESET") { val ex = ErrorMessage.MISC_SP_ACTIVE_SESSION_RESET() assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0413"))) assert( - ex.message.startsWith( - "Error Code: 0413, Error message: " + - "Unexpected stored procedure active session reset." 
- ) - ) + ex.message.startsWith("Error Code: 0413, Error message: " + + "Unexpected stored procedure active session reset.")) } test("MISC_SESSION_HAS_BEEN_CLOSED") { val ex = ErrorMessage.MISC_SESSION_HAS_BEEN_CLOSED() assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0414"))) assert( - ex.message.startsWith( - "Error Code: 0414, Error message: " + - "Cannot perform this operation because the session has been closed." - ) - ) + ex.message.startsWith("Error Code: 0414, Error message: " + + "Cannot perform this operation because the session has been closed.")) } test("MISC_FAILED_CLOSE_SESSION") { val ex = ErrorMessage.MISC_FAILED_CLOSE_SESSION("this error message") assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0415"))) assert( - ex.message.startsWith( - "Error Code: 0415, Error message: " + - "Failed to close this session. The error is: this error message" - ) - ) + ex.message.startsWith("Error Code: 0415, Error message: " + + "Failed to close this session. The error is: this error message")) } test("MISC_CANNOT_CLOSE_STORED_PROC_SESSION") { val ex = ErrorMessage.MISC_CANNOT_CLOSE_STORED_PROC_SESSION() assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0416"))) assert( - ex.message.startsWith( - "Error Code: 0416, Error message: " + - "Cannot close this session because it is used by stored procedure." - ) - ) + ex.message.startsWith("Error Code: 0416, Error message: " + + "Cannot close this session because it is used by stored procedure.")) } test("MISC_INVALID_INT_PARAMETER") { @@ -1024,38 +794,29 @@ class ErrorMessageSuite extends FunSuite { "abc", SnowparkRequestTimeoutInSeconds, MIN_REQUEST_TIMEOUT_IN_SECONDS, - MAX_REQUEST_TIMEOUT_IN_SECONDS - ) + MAX_REQUEST_TIMEOUT_IN_SECONDS) assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0418"))) assert( ex.message.startsWith( "Error Code: 0418, Error message: " + "Invalid value abc for parameter snowpark_request_timeout_in_seconds. 
" + - "Please input an integer value that is between 0 and 604800." - ) - ) + "Please input an integer value that is between 0 and 604800.")) } test("MISC_REQUEST_TIMEOUT") { val ex = ErrorMessage.MISC_REQUEST_TIMEOUT("UDF jar uploading", 10) assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0419"))) assert( - ex.message.startsWith( - "Error Code: 0419, Error message: " + - "UDF jar uploading exceeds the maximum allowed time: 10 second(s)." - ) - ) + ex.message.startsWith("Error Code: 0419, Error message: " + + "UDF jar uploading exceeds the maximum allowed time: 10 second(s).")) } test("MISC_INVALID_RSA_PRIVATE_KEY") { val ex = ErrorMessage.MISC_INVALID_RSA_PRIVATE_KEY("test message") assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0420"))) assert( - ex.message.startsWith( - "Error Code: 0420, Error message: Invalid RSA private key." + - " The error is: test message" - ) - ) + ex.message.startsWith("Error Code: 0420, Error message: Invalid RSA private key." + + " The error is: test message")) } test("MISC_INVALID_STAGE_LOCATION") { @@ -1063,9 +824,7 @@ class ErrorMessageSuite extends FunSuite { assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0421"))) assert( ex.message.startsWith( - "Error Code: 0421, Error message: Invalid stage location: stage. Reason: test message." - ) - ) + "Error Code: 0421, Error message: Invalid stage location: stage. Reason: test message.")) } test("MISC_NO_SERVER_VALUE_NO_DEFAULT_FOR_PARAMETER") { @@ -1074,20 +833,15 @@ class ErrorMessageSuite extends FunSuite { assert( ex.message.startsWith( "Error Code: 0422, Error message: Internal error: Server fetching is disabled" + - " for the parameter someParameter and there is no default value for it." 
- ) - ) + " for the parameter someParameter and there is no default value for it.")) } test("MISC_INVALID_TABLE_FUNCTION_INPUT") { val ex = ErrorMessage.MISC_INVALID_TABLE_FUNCTION_INPUT() assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0423"))) assert( - ex.message.startsWith( - "Error Code: 0423, Error message: Invalid input argument, " + - "Session.tableFunction only supports table function arguments" - ) - ) + ex.message.startsWith("Error Code: 0423, Error message: Invalid input argument, " + + "Session.tableFunction only supports table function arguments")) } test("MISC_INVALID_EXPLODE_ARGUMENT_TYPE") { @@ -1098,9 +852,7 @@ class ErrorMessageSuite extends FunSuite { "Error Code: 0424, Error message: " + "Invalid input argument type, the input argument type of " + "Explode function should be either Map or Array types.\n" + - "The input argument type: Integer" - ) - ) + "The input argument type: Integer")) } test("MISC_UNSUPPORTED_GEOMETRY_FORMAT") { @@ -1110,9 +862,7 @@ class ErrorMessageSuite extends FunSuite { ex.message.startsWith( "Error Code: 0425, Error message: " + "Unsupported Geometry output format: KWT." + - " Please set session parameter GEOMETRY_OUTPUT_FORMAT to GeoJSON." - ) - ) + " Please set session parameter GEOMETRY_OUTPUT_FORMAT to GeoJSON.")) } test("MISC_INVALID_INPUT_QUERY_TAG") { @@ -1122,9 +872,7 @@ class ErrorMessageSuite extends FunSuite { ex.message.startsWith( "Error Code: 0426, Error message: " + "The given query tag must be a valid JSON string. " + - "Ensure it's correctly formatted as JSON." - ) - ) + "Ensure it's correctly formatted as JSON.")) } test("MISC_INVALID_CURRENT_QUERY_TAG") { @@ -1133,9 +881,7 @@ class ErrorMessageSuite extends FunSuite { assert( ex.message.startsWith( "Error Code: 0427, Error message: The query tag of the current session " + - "must be a valid JSON string. Current query tag: myTag" - ) - ) + "must be a valid JSON string. 
Current query tag: myTag")) } test("MISC_FAILED_TO_SERIALIZE_QUERY_TAG") { @@ -1143,8 +889,6 @@ class ErrorMessageSuite extends FunSuite { assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0428"))) assert( ex.message.startsWith( - "Error Code: 0428, Error message: Failed to serialize the query tag into a JSON string." - ) - ) + "Error Code: 0428, Error message: Failed to serialize the query tag into a JSON string.")) } } diff --git a/src/test/scala/com/snowflake/snowpark/ExpressionAndPlanNodeSuite.scala b/src/test/scala/com/snowflake/snowpark/ExpressionAndPlanNodeSuite.scala index fecba46f..374105f4 100644 --- a/src/test/scala/com/snowflake/snowpark/ExpressionAndPlanNodeSuite.scala +++ b/src/test/scala/com/snowflake/snowpark/ExpressionAndPlanNodeSuite.scala @@ -134,8 +134,7 @@ class ExpressionAndPlanNodeSuite extends SNTestBase { None } else { Some(invoked.get ++ invokedSet.get) - } - ) + }) } val exp = func(exprs) @@ -174,24 +173,20 @@ class ExpressionAndPlanNodeSuite extends SNTestBase { childrenChecker(3, data => WithinGroup(data.head, data.tail)) childrenChecker( 5, - data => UpdateMergeExpression(Some(data.head), Map(data(1) -> data(2), data(3) -> data(4))) - ) + data => UpdateMergeExpression(Some(data.head), Map(data(1) -> data(2), data(3) -> data(4)))) childrenChecker( 4, - data => UpdateMergeExpression(None, Map(data(1) -> data(2), data(3) -> data.head)) - ) + data => UpdateMergeExpression(None, Map(data(1) -> data(2), data(3) -> data.head))) unaryChecker(x => DeleteMergeExpression(Some(x))) emptyChecker(DeleteMergeExpression(None)) childrenChecker( 5, - data => InsertMergeExpression(Some(data.head), Seq(data(1), data(2)), Seq(data(3), data(4))) - ) + data => InsertMergeExpression(Some(data.head), Seq(data(1), data(2)), Seq(data(3), data(4)))) childrenChecker( 4, - data => InsertMergeExpression(None, Seq(data(1), data(2)), Seq(data(3), data.head)) - ) + data => InsertMergeExpression(None, Seq(data(1), data(2)), Seq(data(3), data.head))) 
childrenChecker(2, Cube) childrenChecker(2, Rollup) @@ -199,8 +194,7 @@ class ExpressionAndPlanNodeSuite extends SNTestBase { childrenChecker( 5, - data => CaseWhen(Seq((data.head, data(1)), (data(2), data(3))), Some(data(4))) - ) + data => CaseWhen(Seq((data.head, data(1)), (data(2), data(3))), Some(data(4)))) childrenChecker(4, data => CaseWhen(Seq((data.head, data(1)), (data(2), data(3))), None)) childrenChecker(2, MultipleExpression) @@ -299,56 +293,46 @@ class ExpressionAndPlanNodeSuite extends SNTestBase { val windowSpec1 = WindowSpecDefinition(Seq(col4, col5), Seq(order1), windowFrame1) assert( windowSpec1.children.map(_.toString).toSet == - Set(col4.toString, col5.toString, order1.toString, windowFrame1.toString) - ) + Set(col4.toString, col5.toString, order1.toString, windowFrame1.toString)) assert( - windowSpec1.dependentColumnNames.contains(Set("\"F\"", "\"G\"", "\"E\"", "\"A\"", "\"B\"")) - ) + windowSpec1.dependentColumnNames.contains(Set("\"F\"", "\"G\"", "\"E\"", "\"A\"", "\"B\""))) val windowSpec2 = WindowSpecDefinition(Seq(col4, col5), Seq(order1), windowFrame2) assert( windowSpec2.children.map(_.toString).toSet == - Set(col4.toString, col5.toString, order1.toString, windowFrame2.toString) - ) + Set(col4.toString, col5.toString, order1.toString, windowFrame2.toString)) assert(windowSpec2.dependentColumnNames.isEmpty) val windowSpec3 = WindowSpecDefinition(Seq(col4, col5), Seq(order2), windowFrame1) assert( windowSpec3.children.map(_.toString).toSet == - Set(col4.toString, col5.toString, order2.toString, windowFrame1.toString) - ) + Set(col4.toString, col5.toString, order2.toString, windowFrame1.toString)) assert(windowSpec3.dependentColumnNames.isEmpty) val windowSpec4 = WindowSpecDefinition(Seq(col4, unresolvedCol2), Seq(order1), windowFrame1) assert( windowSpec4.children.map(_.toString).toSet == - Set(col4.toString, unresolvedCol2.toString, order1.toString, windowFrame1.toString) - ) + Set(col4.toString, unresolvedCol2.toString, order1.toString, 
windowFrame1.toString)) assert(windowSpec4.dependentColumnNames.isEmpty) val window1 = WindowExpression(col6, windowSpec1) assert( window1.children.map(_.toString).toSet == - Set(col6.toString, windowSpec1.toString) - ) + Set(col6.toString, windowSpec1.toString)) assert( window1.dependentColumnNames.contains( - Set("\"F\"", "\"G\"", "\"E\"", "\"A\"", "\"B\"", "\"H\"") - ) - ) + Set("\"F\"", "\"G\"", "\"E\"", "\"A\"", "\"B\"", "\"H\""))) val window2 = WindowExpression(col6, windowSpec2) assert( window2.children.map(_.toString).toSet == - Set(col6.toString, windowSpec2.toString) - ) + Set(col6.toString, windowSpec2.toString)) assert(window2.dependentColumnNames.isEmpty) val window3 = WindowExpression(unresolvedCol1, windowSpec1) assert( window3.children.map(_.toString).toSet == - Set(unresolvedCol1.toString, windowSpec1.toString) - ) + Set(unresolvedCol1.toString, windowSpec1.toString)) assert(window3.dependentColumnNames.isEmpty) } @@ -358,8 +342,7 @@ class ExpressionAndPlanNodeSuite extends SNTestBase { dataSize: Int, // input: a list of generated child expressions, // output: a reference to the expression being tested - func: Seq[Expression] => Expression - ): Unit = { + func: Seq[Expression] => Expression): Unit = { val args: Seq[Literal] = (0 until dataSize) .map(functions.lit) .map(_.expr.asInstanceOf[Literal]) @@ -406,37 +389,31 @@ class ExpressionAndPlanNodeSuite extends SNTestBase { analyzerChecker(2, internal.analyzer.TableFunction("dummy", _)) analyzerChecker( 2, - data => NamedArgumentsTableFunction("dummy", Map("a" -> data.head, "b" -> data(1))) - ) + data => NamedArgumentsTableFunction("dummy", Map("a" -> data.head, "b" -> data(1)))) analyzerChecker( 4, - data => GroupingSetsExpression(Seq(Set(data.head, data(1)), Set(data(2), data(3)))) - ) + data => GroupingSetsExpression(Seq(Set(data.head, data(1)), Set(data(2), data(3))))) analyzerChecker(2, SnowflakeUDF("dummy", _, IntegerType)) leafAnalyzerChecker(Literal(1, Some(IntegerType))) analyzerChecker(3, 
data => SortOrder(data.head, Ascending, NullsLast, data.tail.toSet)) analyzerChecker( 5, - data => UpdateMergeExpression(Some(data.head), Map(data(1) -> data(3), data(2) -> data(4))) - ) + data => UpdateMergeExpression(Some(data.head), Map(data(1) -> data(3), data(2) -> data(4)))) analyzerChecker( 4, - data => UpdateMergeExpression(None, Map(data.head -> data(2), data(1) -> data(3))) - ) + data => UpdateMergeExpression(None, Map(data.head -> data(2), data(1) -> data(3)))) unaryAnalyzerChecker(data => DeleteMergeExpression(Some(data))) leafAnalyzerChecker(DeleteMergeExpression(None)) analyzerChecker( 5, - data => InsertMergeExpression(Some(data.head), Seq(data(1), data(2)), Seq(data(3), data(4))) - ) + data => InsertMergeExpression(Some(data.head), Seq(data(1), data(2)), Seq(data(3), data(4)))) analyzerChecker( 4, - data => InsertMergeExpression(None, Seq(data.head, data(1)), Seq(data(2), data(3))) - ) + data => InsertMergeExpression(None, Seq(data.head, data(1)), Seq(data(2), data(3)))) analyzerChecker(2, Cube) analyzerChecker(2, Rollup) @@ -445,8 +422,7 @@ class ExpressionAndPlanNodeSuite extends SNTestBase { analyzerChecker( 5, - data => CaseWhen(Seq((data.head, data(1)), (data(2), data(3))), Some(data(4))) - ) + data => CaseWhen(Seq((data.head, data(1)), (data(2), data(3))), Some(data(4)))) analyzerChecker(4, data => CaseWhen(Seq((data.head, data(1)), (data(2), data(3))))) @@ -507,7 +483,7 @@ class ExpressionAndPlanNodeSuite extends SNTestBase { val att2 = Attribute("b", IntegerType) val func1: Expression => Expression = { case _: Attribute => att2 - case x => x + case x => x } val exp = Star(Seq(att1)) assert(exp.analyze(func1).children == Seq(att2)) @@ -525,15 +501,15 @@ class ExpressionAndPlanNodeSuite extends SNTestBase { val exp = WindowSpecDefinition(Seq(lit0), Seq(order0), UnspecifiedFrame) val func1: Expression => Expression = { - case _: SortOrder => order1 + case _: SortOrder => order1 case _: WindowFrame => frame - case Literal(0, _) => lit1 - case x 
=> x + case Literal(0, _) => lit1 + case x => x } val func2: Expression => Expression = { case _: WindowSpecDefinition => lit1 - case x => x + case x => x } val exp1 = exp.analyze(func1) @@ -550,14 +526,14 @@ class ExpressionAndPlanNodeSuite extends SNTestBase { val window1 = WindowSpecDefinition(Seq(lit0), Seq(order), UnspecifiedFrame) val window2 = WindowSpecDefinition(Seq(lit1), Seq(order), UnspecifiedFrame) val func1: Expression => Expression = { - case _: Literal => lit1 + case _: Literal => lit1 case _: WindowSpecDefinition => window2 - case x => x + case x => x } val func2: Expression => Expression = { case _: WindowExpression => lit0 - case x => x + case x => x } val exp = WindowExpression(lit0, window1) @@ -595,8 +571,7 @@ class ExpressionAndPlanNodeSuite extends SNTestBase { .newCols .head .asInstanceOf[Attribute] - .name == "\"C\"" - ) + .name == "\"C\"") val plan1 = WithColumns(Seq(attr3), child2) assert(plan1.aliasMap.isEmpty) @@ -606,8 +581,7 @@ class ExpressionAndPlanNodeSuite extends SNTestBase { .newCols .head .asInstanceOf[Attribute] - .name == "\"COL3\"" - ) + .name == "\"COL3\"") } test("DropColumns - Analyzer") { @@ -619,8 +593,7 @@ class ExpressionAndPlanNodeSuite extends SNTestBase { .columns .head .asInstanceOf[Attribute] - .name == "\"C\"" - ) + .name == "\"C\"") val plan1 = DropColumns(Seq(attr3), child2) assert(plan1.aliasMap.isEmpty) @@ -630,8 +603,7 @@ class ExpressionAndPlanNodeSuite extends SNTestBase { .columns .head .asInstanceOf[Attribute] - .name == "\"COL3\"" - ) + .name == "\"COL3\"") } test("TableFunctionRelation - Analyzer") { @@ -715,8 +687,7 @@ class ExpressionAndPlanNodeSuite extends SNTestBase { val plan = Sort(Seq(order), child1) assert(plan.aliasMap == map2) assert( - plan.analyzed.asInstanceOf[Sort].order.head.child.asInstanceOf[Attribute].name == "\"C\"" - ) + plan.analyzed.asInstanceOf[Sort].order.head.child.asInstanceOf[Attribute].name == "\"C\"") val plan1 = Sort(Seq(order), child2) assert(plan1.aliasMap.isEmpty) @@ 
-729,12 +700,10 @@ class ExpressionAndPlanNodeSuite extends SNTestBase { assert(plan.aliasMap == map2) assert( plan.analyzed.asInstanceOf[LimitOnSort].order.head.child.asInstanceOf[Attribute].name - == "\"C\"" - ) + == "\"C\"") assert( plan.analyzed.asInstanceOf[LimitOnSort].limitExpr.asInstanceOf[Attribute].name - == "\"C\"" - ) + == "\"C\"") val plan1 = LimitOnSort(child2, attr3, Seq(order)) assert(plan1.aliasMap.isEmpty) @@ -750,8 +719,7 @@ class ExpressionAndPlanNodeSuite extends SNTestBase { .groupingExpressions .head .asInstanceOf[Attribute] - .name == "\"C\"" - ) + .name == "\"C\"") val plan1 = Aggregate(Seq.empty, Seq(attr3), child1) assert(plan1.aliasMap == map2) @@ -770,26 +738,21 @@ class ExpressionAndPlanNodeSuite extends SNTestBase { val plan1 = Pivot(attr1, Seq(attr3), Seq.empty, child1) assert(plan1.aliasMap == map2) assert( - plan1.analyzed.asInstanceOf[Pivot].pivotColumn.asInstanceOf[Attribute].name == "\"COL1\"" - ) + plan1.analyzed.asInstanceOf[Pivot].pivotColumn.asInstanceOf[Attribute].name == "\"COL1\"") assert( - plan1.analyzed.asInstanceOf[Pivot].pivotValues.head.asInstanceOf[Attribute].name == "\"C\"" - ) + plan1.analyzed.asInstanceOf[Pivot].pivotValues.head.asInstanceOf[Attribute].name == "\"C\"") val plan2 = Pivot(attr1, Seq.empty, Seq(attr3), child1) assert(plan2.aliasMap == map2) assert( - plan2.analyzed.asInstanceOf[Pivot].pivotColumn.asInstanceOf[Attribute].name == "\"COL1\"" - ) + plan2.analyzed.asInstanceOf[Pivot].pivotColumn.asInstanceOf[Attribute].name == "\"COL1\"") assert( - plan2.analyzed.asInstanceOf[Pivot].aggregates.head.asInstanceOf[Attribute].name == "\"C\"" - ) + plan2.analyzed.asInstanceOf[Pivot].aggregates.head.asInstanceOf[Attribute].name == "\"C\"") val plan3 = Pivot(attr3, Seq.empty, Seq.empty, child2) assert(plan3.aliasMap.isEmpty) assert( - plan3.analyzed.asInstanceOf[Pivot].pivotColumn.asInstanceOf[Attribute].name == "\"COL3\"" - ) + plan3.analyzed.asInstanceOf[Pivot].pivotColumn.asInstanceOf[Attribute].name == 
"\"COL3\"") } test("Filter - Analyzer") { @@ -811,8 +774,7 @@ class ExpressionAndPlanNodeSuite extends SNTestBase { .projectList .head .asInstanceOf[Attribute] - .name == "\"C\"" - ) + .name == "\"C\"") val plan1 = Project(Seq(attr3), child2) assert(plan1.aliasMap.isEmpty) @@ -822,8 +784,7 @@ class ExpressionAndPlanNodeSuite extends SNTestBase { .projectList .head .asInstanceOf[Attribute] - .name == "\"COL3\"" - ) + .name == "\"COL3\"") } test("ProjectAndFilter - Analyzer") { @@ -835,15 +796,13 @@ class ExpressionAndPlanNodeSuite extends SNTestBase { .projectList .head .asInstanceOf[Attribute] - .name == "\"C\"" - ) + .name == "\"C\"") assert( plan.analyzed .asInstanceOf[ProjectAndFilter] .condition .asInstanceOf[Attribute] - .name == "\"C\"" - ) + .name == "\"C\"") val plan1 = ProjectAndFilter(Seq(attr3), attr3, child2) assert(plan1.aliasMap.isEmpty) @@ -853,15 +812,13 @@ class ExpressionAndPlanNodeSuite extends SNTestBase { .projectList .head .asInstanceOf[Attribute] - .name == "\"COL3\"" - ) + .name == "\"COL3\"") assert( plan1.analyzed .asInstanceOf[ProjectAndFilter] .condition .asInstanceOf[Attribute] - .name == "\"COL3\"" - ) + .name == "\"COL3\"") } test("CreateViewCommand - Analyzer") { @@ -887,8 +844,7 @@ class ExpressionAndPlanNodeSuite extends SNTestBase { .args .head .asInstanceOf[Attribute] - .name == "\"C\"" - ) + .name == "\"C\"") val plan2 = Lateral(child2, tf) assert(plan2.aliasMap.isEmpty) @@ -900,8 +856,7 @@ class ExpressionAndPlanNodeSuite extends SNTestBase { .args .head .asInstanceOf[Attribute] - .name == "\"COL3\"" - ) + .name == "\"COL3\"") } test("Limit - Analyzer") { @@ -912,8 +867,7 @@ class ExpressionAndPlanNodeSuite extends SNTestBase { .asInstanceOf[Limit] .limitExpr .asInstanceOf[Attribute] - .name == "\"C\"" - ) + .name == "\"C\"") val plan1 = Limit(attr3, child2) assert(plan1.aliasMap.isEmpty) @@ -922,8 +876,7 @@ class ExpressionAndPlanNodeSuite extends SNTestBase { .asInstanceOf[Limit] .limitExpr .asInstanceOf[Attribute] - .name == 
"\"COL3\"" - ) + .name == "\"COL3\"") } test("TableFunctionJoin - Analyzer") { @@ -939,8 +892,7 @@ class ExpressionAndPlanNodeSuite extends SNTestBase { .args .head .asInstanceOf[Attribute] - .name == "\"C\"" - ) + .name == "\"C\"") val plan2 = TableFunctionJoin(child2, tf, None) assert(plan2.aliasMap.isEmpty) @@ -952,8 +904,7 @@ class ExpressionAndPlanNodeSuite extends SNTestBase { .args .head .asInstanceOf[Attribute] - .name == "\"COL3\"" - ) + .name == "\"COL3\"") } test("TableMerge - Analyzer") { @@ -965,8 +916,7 @@ class ExpressionAndPlanNodeSuite extends SNTestBase { .asInstanceOf[TableMerge] .joinExpr .asInstanceOf[Attribute] - .name == "\"C\"" - ) + .name == "\"C\"") assert( plan1.analyzed .asInstanceOf[TableMerge] @@ -976,8 +926,7 @@ class ExpressionAndPlanNodeSuite extends SNTestBase { .condition .get .asInstanceOf[Attribute] - .name == "\"C\"" - ) + .name == "\"C\"") val plan2 = TableMerge("dummy", child2, attr3, Seq(me)) assert(plan2.aliasMap.isEmpty) @@ -986,8 +935,7 @@ class ExpressionAndPlanNodeSuite extends SNTestBase { .asInstanceOf[TableMerge] .joinExpr .asInstanceOf[Attribute] - .name == "\"COL3\"" - ) + .name == "\"COL3\"") assert( plan2.analyzed .asInstanceOf[TableMerge] @@ -997,8 +945,7 @@ class ExpressionAndPlanNodeSuite extends SNTestBase { .condition .get .asInstanceOf[Attribute] - .name == "\"COL3\"" - ) + .name == "\"COL3\"") } test("SnowflakeCreateTable - Analyzer") { @@ -1020,8 +967,7 @@ class ExpressionAndPlanNodeSuite extends SNTestBase { .condition .get .asInstanceOf[Attribute] - .name == "\"C\"" - ) + .name == "\"C\"") plan1.analyzed.asInstanceOf[TableUpdate].assignments.foreach { case (key: Attribute, value: Attribute) => assert(key.name == "\"C\"") @@ -1036,8 +982,7 @@ class ExpressionAndPlanNodeSuite extends SNTestBase { .condition .get .asInstanceOf[Attribute] - .name == "\"COL3\"" - ) + .name == "\"COL3\"") plan2.analyzed.asInstanceOf[TableUpdate].assignments.foreach { case (key: Attribute, value: Attribute) => assert(key.name 
== "\"COL3\"") @@ -1054,8 +999,7 @@ class ExpressionAndPlanNodeSuite extends SNTestBase { .condition .get .asInstanceOf[Attribute] - .name == "\"C\"" - ) + .name == "\"C\"") val plan2 = TableDelete("dummy", Some(attr3), None) assert(plan2.aliasMap.isEmpty) @@ -1065,8 +1009,7 @@ class ExpressionAndPlanNodeSuite extends SNTestBase { .condition .get .asInstanceOf[Attribute] - .name == "\"COL3\"" - ) + .name == "\"COL3\"") } def binaryNodeAnalyzerChecker(func: (LogicalPlan, LogicalPlan) => LogicalPlan): Unit = { @@ -1103,8 +1046,7 @@ class ExpressionAndPlanNodeSuite extends SNTestBase { .condition .get .asInstanceOf[Attribute] - .name == "\"C\"" - ) + .name == "\"C\"") val plan1 = Join(child2, child3, LeftOuter, Some(attr3)) assert(plan1.aliasMap.isEmpty) @@ -1114,8 +1056,7 @@ class ExpressionAndPlanNodeSuite extends SNTestBase { .condition .get .asInstanceOf[Attribute] - .name == "\"COL3\"" - ) + .name == "\"COL3\"") } // updateChildren, simplifier @@ -1125,13 +1066,13 @@ class ExpressionAndPlanNodeSuite extends SNTestBase { plan match { // small change for future verification case Range(_, end, _) => Range(end, 1, 1) - case _ => plan + case _ => plan } val plan = func(testData) val newPlan = plan.updateChildren(testFunc) assert(newPlan.children.zipWithIndex.forall { case (Range(start, _, _), i) if start == i => true - case _ => false + case _ => false }) } def leafSimplifierChecker(plan: LogicalPlan): Unit = { diff --git a/src/test/scala/com/snowflake/snowpark/JavaAPISuite.scala b/src/test/scala/com/snowflake/snowpark/JavaAPISuite.scala index cdccb9b6..397a0a60 100644 --- a/src/test/scala/com/snowflake/snowpark/JavaAPISuite.scala +++ b/src/test/scala/com/snowflake/snowpark/JavaAPISuite.scala @@ -1,133 +1,10 @@ package com.snowflake.snowpark import org.scalatest.FunSuite -import com.snowflake.snowpark_test._ import java.io.ByteArrayOutputStream -// entry of all Java JUnit test suite -// modify this class each time when adding a new Java test suite. 
@JavaAPITest class JavaAPISuite extends FunSuite { - test("Java Session") { - TestRunner.run(classOf[JavaSessionSuite]) - } - - test("Java Session without stored proc support") { - TestRunner.run(classOf[JavaSessionNonStoredProcSuite]) - } - - test("Java Variant") { - TestRunner.run(classOf[JavaVariantSuite]) - } - - test("Java Geography") { - TestRunner.run(classOf[JavaGeographySuite]) - } - - test("Java Row") { - TestRunner.run(classOf[JavaRowSuite]) - } - - test("Java DataType") { - TestRunner.run(classOf[JavaDataTypesSuite]) - } - - test("Java Column") { - TestRunner.run(classOf[JavaColumnSuite]) - } - - test("Java Window") { - TestRunner.run(classOf[JavaWindowSuite]) - } - - test("Java Functions") { - TestRunner.run(classOf[JavaFunctionSuite]) - } - - test("Java UDF") { - TestRunner.run(classOf[JavaUDFSuite]) - } - - test("Java UDF without stored proc support") { - TestRunner.run(classOf[JavaUDFNonStoredProcSuite]) - } - - test("Java UDTF") { - TestRunner.run(classOf[JavaUDTFSuite]) - } - - test("Java UDTF without stored proc support") { - TestRunner.run(classOf[JavaUDTFNonStoredProcSuite]) - } - - test("Java DataFrame") { - TestRunner.run(classOf[JavaDataFrameSuite]) - } - - test("Java DataFrame which doesn't support stored proc") { - TestRunner.run(classOf[JavaDataFrameNonStoredProcSuite]) - } - - test("Java DataFrameWriter") { - TestRunner.run(classOf[JavaDataFrameWriterSuite]) - } - - test("Java DataFrameReader") { - TestRunner.run(classOf[JavaDataFrameReaderSuite]) - } - - test("Java DataFrame Aggregate") { - TestRunner.run(classOf[JavaDataFrameAggregateSuite]) - } - - test("Java Internal Utils") { - TestRunner.run(classOf[JavaUtilsSuite]) - } - - test("Java Updatable") { - TestRunner.run(classOf[JavaUpdatableSuite]) - } - - test("Java DataFrameNaFunctions") { - TestRunner.run(classOf[JavaDataFrameNaFunctionsSuite]) - } - - test("Java DataFrameStatFunctions") { - TestRunner.run(classOf[JavaDataFrameStatFunctionsSuite]) - } - - test("Java TableFunction") { 
- TestRunner.run(classOf[JavaTableFunctionSuite]) - } - - test("Java CopyableDataFrame") { - TestRunner.run(classOf[JavaCopyableDataFrameSuite]) - } - - test("Java AsyncJob") { - TestRunner.run(classOf[JavaAsyncJobSuite]) - } - - test("Java FileOperation") { - TestRunner.run(classOf[JavaFileOperationSuite]) - } - - test("Java StoredProcedure") { - TestRunner.run(classOf[JavaStoredProcedureSuite]) - } - - test("Java SProc without stored proc support") { - TestRunner.run(classOf[JavaSProcNonStoredProcSuite]) - } - - test("Java OpenTelemetry") { - TestRunner.run(classOf[JavaOpenTelemetrySuite]) - } - - test("Java UDX OpenTelemetry") { - TestRunner.run(classOf[JavaUDXOpenTelemetrySuite]) - } - // some tests can't be implemented in Java are listed below // console redirect doesn't work in Java since run JUnit from Scala @@ -141,6 +18,7 @@ class JavaAPISuite extends FunSuite { Console.withOut(outContent) { df.explain() } + println(df.explain()) val result = outContent.toString("UTF-8") assert(result.contains("Query List:")) assert(result.contains("select\n" + " *\n" + "from\n" + "values(1, 2),(3, 4) as t(a, b)")) diff --git a/src/test/scala/com/snowflake/snowpark/MethodChainSuite.scala b/src/test/scala/com/snowflake/snowpark/MethodChainSuite.scala index 8b462551..480f9c1a 100644 --- a/src/test/scala/com/snowflake/snowpark/MethodChainSuite.scala +++ b/src/test/scala/com/snowflake/snowpark/MethodChainSuite.scala @@ -30,8 +30,7 @@ class MethodChainSuite extends TestData { .toDF(Array("a3", "b3", "c3")), "toDF", "toDF", - "toDF" - ) + "toDF") } test("sort") { @@ -42,8 +41,7 @@ class MethodChainSuite extends TestData { .sort(Array(col("a"))), "sort", "sort", - "sort" - ) + "sort") } test("alias") { @@ -64,8 +62,7 @@ class MethodChainSuite extends TestData { "select", "select", "select", - "select" - ) + "select") } test("drop") { @@ -112,14 +109,12 @@ class MethodChainSuite extends TestData { test("groupByGroupingSets") { checkMethodChain( 
df1.groupByGroupingSets(GroupingSets(Set(col("a")))).count(), - "groupByGroupingSets.count" - ) + "groupByGroupingSets.count") checkMethodChain( df1 .groupByGroupingSets(Seq(GroupingSets(Set(col("a"))))) .builtin("count")(col("a")), - "groupByGroupingSets.builtin" - ) + "groupByGroupingSets.builtin") } test("cube") { @@ -203,20 +198,17 @@ class MethodChainSuite extends TestData { df.join(tf, Seq(df("b")), Seq(df("a")), Seq(df("b"))), "select", "toDF", - "join" - ) + "join") checkMethodChain( df.join(tf, Map("arg1" -> df("b")), Seq(df("a")), Seq(df("b"))), "select", "toDF", - "join" - ) + "join") checkMethodChain( df.join(tf(Map("arg1" -> df("b"))), Seq(df("a")), Seq(df("b"))), "select", "toDF", - "join" - ) + "join") val df3 = session.sql("select * from values('[1,2,3]') as T(a)") checkMethodChain(df3.join(flatten, Map("input" -> parse_json(df("a")))), "join") @@ -249,8 +241,7 @@ class MethodChainSuite extends TestData { checkMethodChain( nullData3.na.fill(Map("flo" -> 12.3, "int" -> 11, "boo" -> false, "str" -> "f")), "na", - "fill" - ) + "fill") checkMethodChain(nullData3.na.replace("flo", Map(2 -> 300, 1 -> 200)), "na", "replace") } @@ -264,7 +255,6 @@ class MethodChainSuite extends TestData { checkMethodChain(table1.flatten(table1("value")), "flatten") checkMethodChain( table1.flatten(table1("value"), "", outer = false, recursive = false, "both"), - "flatten" - ) + "flatten") } } diff --git a/src/test/scala/com/snowflake/snowpark/NewColumnReferenceSuite.scala b/src/test/scala/com/snowflake/snowpark/NewColumnReferenceSuite.scala index 08a199e2..3281a32c 100644 --- a/src/test/scala/com/snowflake/snowpark/NewColumnReferenceSuite.scala +++ b/src/test/scala/com/snowflake/snowpark/NewColumnReferenceSuite.scala @@ -33,8 +33,7 @@ class NewColumnReferenceSuite extends SNTestBase { | |--B: Long (nullable = false) | |--B: Long (nullable = false) | |--C: Long (nullable = false) - |""".stripMargin - ) + |""".stripMargin) } test("show", JavaStoredProcExclude) { @@ -45,8 
+44,7 @@ class NewColumnReferenceSuite extends SNTestBase { |------------------------- ||1 |2 |2 |3 | |------------------------- - |""".stripMargin - ) + |""".stripMargin) assert(!df1_disabled.join(df2_disabled).showString(10).contains(""""B"""")) } @@ -65,14 +63,11 @@ class NewColumnReferenceSuite extends SNTestBase { "b", TestInternalAlias("c"), TestInternalAlias("d"), - "e" - ) - ) + "e")) val df5 = df4.drop(df2("c")) verifyOutputName( df5.output, - Seq("a", "c", TestInternalAlias("d"), "b", TestInternalAlias("d"), "e") - ) + Seq("a", "c", TestInternalAlias("d"), "b", TestInternalAlias("d"), "e")) val df1_disabled1 = disabledHideInternalAliasSession .createDataFrame(Seq((1, 2, 3, 4))) .toDF("a", "b", "c", "d") @@ -90,9 +85,7 @@ class NewColumnReferenceSuite extends SNTestBase { TestInternalAlias("b"), TestInternalAlias("c"), TestInternalAlias("d"), - "e" - ) - ) + "e")) val df8 = df7.drop(df2_disabled1("c")) verifyOutputName( df8.output, @@ -102,9 +95,7 @@ class NewColumnReferenceSuite extends SNTestBase { TestInternalAlias("d"), TestInternalAlias("b"), TestInternalAlias("d"), - "e" - ) - ) + "e")) } test("dedup - select", JavaStoredProcExclude) { @@ -114,8 +105,7 @@ class NewColumnReferenceSuite extends SNTestBase { val df4 = df3.select(df1("a"), df1("b"), df1("d"), df2("c"), df2("d"), df2("e")) verifyOutputName( df4.output, - Seq("a", "b", TestInternalAlias("d"), "c", TestInternalAlias("d"), "e") - ) + Seq("a", "b", TestInternalAlias("d"), "c", TestInternalAlias("d"), "e")) assert( df4.showString(10) == """------------------------------------- @@ -123,8 +113,7 @@ class NewColumnReferenceSuite extends SNTestBase { |------------------------------------- ||1 |2 |4 |3 |4 |5 | |------------------------------------- - |""".stripMargin - ) + |""".stripMargin) assert( df4.schema.treeString(0) == """root @@ -134,8 +123,7 @@ class NewColumnReferenceSuite extends SNTestBase { | |--C: Long (nullable = false) | |--D: Long (nullable = false) | |--E: Long (nullable = false) 
- |""".stripMargin - ) + |""".stripMargin) val df1_disabled1 = disabledHideInternalAliasSession .createDataFrame(Seq((1, 2, 3, 4))) .toDF("a", "b", "c", "d") @@ -149,8 +137,7 @@ class NewColumnReferenceSuite extends SNTestBase { df1_disabled1("d"), df2_disabled1("c"), df2_disabled1("d"), - df2_disabled1("e") - ) + df2_disabled1("e")) verifyOutputName( df6.output, Seq( @@ -159,9 +146,7 @@ class NewColumnReferenceSuite extends SNTestBase { TestInternalAlias("d"), TestInternalAlias("c"), TestInternalAlias("d"), - "e" - ) - ) + "e")) val showString = df6.showString(10) assert(!showString.contains(""""B"""")) assert(!showString.contains(""""C"""")) @@ -187,8 +172,7 @@ class NewColumnReferenceSuite extends SNTestBase { |------------------------------------- ||1 |2 |2 |3 |5 |10 | |------------------------------------- - |""".stripMargin - ) + |""".stripMargin) val df1_disabled1 = disabledHideInternalAliasSession .createDataFrame(Seq((1, 2, 3, 4))) .toDF("a", "b", "c", "d") @@ -203,12 +187,10 @@ class NewColumnReferenceSuite extends SNTestBase { df2_disabled1("b").as("f"), df2_disabled1("c"), df2_disabled1("e"), - lit(10).as("c") - ) + lit(10).as("c")) verifyOutputName( df6.output, - Seq("a", TestInternalAlias("b"), "f", TestInternalAlias("c"), "e", "c") - ) + Seq("a", TestInternalAlias("b"), "f", TestInternalAlias("c"), "e", "c")) val showString = df6.showString(10) assert(!showString.contains(""""B"""")) assert(showString.contains(""""C"""")) @@ -323,8 +305,7 @@ class NewColumnReferenceSuite extends SNTestBase { verifyNode(_ => func, Seq.empty) private def verifyNode( func: Seq[LogicalPlan] => LogicalPlan, - children: Seq[LogicalPlan] - ): Unit = { + children: Seq[LogicalPlan]): Unit = { val plan = func(children) val expected = children .map(_.internalRenamedColumns) diff --git a/src/test/scala/com/snowflake/snowpark/OpenTelemetryEnabled.scala b/src/test/scala/com/snowflake/snowpark/OpenTelemetryEnabled.scala index 6684d443..902f8fca 100644 --- 
a/src/test/scala/com/snowflake/snowpark/OpenTelemetryEnabled.scala +++ b/src/test/scala/com/snowflake/snowpark/OpenTelemetryEnabled.scala @@ -39,8 +39,7 @@ trait OpenTelemetryEnabled extends TestData { funcName: String, fileName: String, lineNumber: Int, - methodChain: String - ): Unit = + methodChain: String): Unit = checkSpan(className, funcName) { span => { assert(span.getTotalAttributeCount == 3) @@ -57,8 +56,7 @@ trait OpenTelemetryEnabled extends TestData { lineNumber: Int, execName: String, execHandler: String, - execFilePath: String - ): Unit = + execFilePath: String): Unit = checkSpan(className, funcName) { span => { assert(span.getTotalAttributeCount == 5) @@ -67,12 +65,10 @@ trait OpenTelemetryEnabled extends TestData { assert(span.getAttributes.get(AttributeKey.stringKey("snow.executable.name")) == execName) assert( span.getAttributes - .get(AttributeKey.stringKey("snow.executable.handler")) == execHandler - ) + .get(AttributeKey.stringKey("snow.executable.handler")) == execHandler) assert( span.getAttributes - .get(AttributeKey.stringKey("snow.executable.filepath")) == execFilePath - ) + .get(AttributeKey.stringKey("snow.executable.filepath")) == execFilePath) } } diff --git a/src/test/scala/com/snowflake/snowpark/ParameterSuite.scala b/src/test/scala/com/snowflake/snowpark/ParameterSuite.scala index fe8d7c67..75365e5b 100644 --- a/src/test/scala/com/snowflake/snowpark/ParameterSuite.scala +++ b/src/test/scala/com/snowflake/snowpark/ParameterSuite.scala @@ -32,8 +32,7 @@ class ParameterSuite extends SNTestBase { assert( sessionWithApplicationName.conn.connection.getSFBaseSession.getConnectionPropertiesMap - .get(SFSessionProperty.APPLICATION) == applicationName - ) + .get(SFSessionProperty.APPLICATION) == applicationName) } test("url") { diff --git a/src/test/scala/com/snowflake/snowpark/ReplSuite.scala b/src/test/scala/com/snowflake/snowpark/ReplSuite.scala index 9eb9218e..cba06aa4 100644 --- a/src/test/scala/com/snowflake/snowpark/ReplSuite.scala +++ 
b/src/test/scala/com/snowflake/snowpark/ReplSuite.scala @@ -75,12 +75,10 @@ class ReplSuite extends TestData { Files.copy( Paths.get(defaultProfile), Paths.get(s"$workDir/$defaultProfile"), - StandardCopyOption.REPLACE_EXISTING - ) + StandardCopyOption.REPLACE_EXISTING) Files.write( Paths.get(s"$workDir/file.txt"), - (preLoad + code + "sys.exit\n").getBytes(StandardCharsets.UTF_8) - ) + (preLoad + code + "sys.exit\n").getBytes(StandardCharsets.UTF_8)) s"cat $workDir/file.txt ".#|(s"$workDir/run.sh").!!.replaceAll("scala> ", "") } diff --git a/src/test/scala/com/snowflake/snowpark/ResultAttributesSuite.scala b/src/test/scala/com/snowflake/snowpark/ResultAttributesSuite.scala index 61b12477..37726640 100644 --- a/src/test/scala/com/snowflake/snowpark/ResultAttributesSuite.scala +++ b/src/test/scala/com/snowflake/snowpark/ResultAttributesSuite.scala @@ -24,8 +24,7 @@ class ResultAttributesSuite extends SNTestBase { .map { case (tpe, index) => s"col_$index $tpe" } - .mkString(",") - ) + .mkString(",")) attribute = getTableAttributes(tableName) } finally { dropTable(name) @@ -80,8 +79,7 @@ class ResultAttributesSuite extends SNTestBase { "timestamp" -> TimestampType, "timestamp_ltz" -> TimestampType, "timestamp_ntz" -> TimestampType, - "timestamp_tz" -> TimestampType - ) + "timestamp_tz" -> TimestampType) val attribute = getAttributesWithTypes(tableName, dates.map(_._1)) assert(attribute.length == dates.length) dates.indices.foreach(index => assert(attribute(index).dataType == dates(index)._2)) @@ -94,12 +92,10 @@ class ResultAttributesSuite extends SNTestBase { assert( attribute(0).dataType == - VariantType - ) + VariantType) assert( attribute(1).dataType == - MapType(StringType, StringType) - ) + MapType(StringType, StringType)) } test("Array Type") { @@ -109,8 +105,6 @@ class ResultAttributesSuite extends SNTestBase { variants.indices.foreach(index => assert( attribute(index).dataType == - ArrayType(StringType) - ) - ) + ArrayType(StringType))) } } diff --git 
a/src/test/scala/com/snowflake/snowpark/SFTestUtils.scala b/src/test/scala/com/snowflake/snowpark/SFTestUtils.scala index f598477f..25cd9450 100644 --- a/src/test/scala/com/snowflake/snowpark/SFTestUtils.scala +++ b/src/test/scala/com/snowflake/snowpark/SFTestUtils.scala @@ -35,8 +35,7 @@ trait SFTestUtils { TestUtils.insertIntoTable(name, data, session) def uploadFileToStage(stageName: String, fileName: String, compress: Boolean)(implicit - session: Session - ): Unit = + session: Session): Unit = TestUtils.uploadFileToStage(stageName, fileName, compress, session) def verifySchema(sql: String, expectedSchema: StructType)(implicit session: Session): Unit = diff --git a/src/test/scala/com/snowflake/snowpark/SNTestBase.scala b/src/test/scala/com/snowflake/snowpark/SNTestBase.scala index c4fcc55a..e1245569 100644 --- a/src/test/scala/com/snowflake/snowpark/SNTestBase.scala +++ b/src/test/scala/com/snowflake/snowpark/SNTestBase.scala @@ -80,8 +80,7 @@ trait SNTestBase extends FunSuite with BeforeAndAfterAll with SFTestUtils with S TypeMap("object", "object", Types.VARCHAR, MapType(StringType, StringType)), TypeMap("array", "array", Types.VARCHAR, ArrayType(StringType)), TypeMap("geography", "geography", Types.VARCHAR, GeographyType), - TypeMap("geometry", "geometry", Types.VARCHAR, GeometryType) - ) + TypeMap("geometry", "geometry", Types.VARCHAR, GeometryType)) implicit lazy val session: Session = { TestUtils.tryToLoadFipsProvider() @@ -97,7 +96,7 @@ trait SNTestBase extends FunSuite with BeforeAndAfterAll with SFTestUtils with S def equalsIgnoreCase(a: Option[String], b: Option[String]): Boolean = { (a, b) match { case (Some(l), Some(r)) => l.equalsIgnoreCase(r) - case _ => a == b + case _ => a == b } } @@ -171,8 +170,7 @@ trait SNTestBase extends FunSuite with BeforeAndAfterAll with SFTestUtils with S thunk: => T, parameter: String, value: String, - skipIfParamNotExist: Boolean = false - ): Unit = { + skipIfParamNotExist: Boolean = false): Unit = { var 
parameterNotExist = false try { session.runQuery(s"alter session set $parameter = $value") @@ -212,8 +210,7 @@ trait SNTestBase extends FunSuite with BeforeAndAfterAll with SFTestUtils with S columnTypeName: String, precision: Int, scale: Int, - signed: Boolean - ): DataType = + signed: Boolean): DataType = ServerConnection.getDataType(sqlType, columnTypeName, precision, scale, signed) def loadConfFromFile(path: String): Map[String, String] = Session.loadConfFromFile(path) @@ -227,8 +224,7 @@ trait SNTestBase extends FunSuite with BeforeAndAfterAll with SFTestUtils with S s"select query_text from " + s"table(information_schema.QUERY_HISTORY_BY_SESSION()) " + s"where query_tag ='$tag'", - sess - ) + sess) val result = statement.getResultSet val resArray = new ArrayBuffer[String]() while (result.next()) { @@ -243,8 +239,7 @@ trait SNTestBase extends FunSuite with BeforeAndAfterAll with SFTestUtils with S s"select query_tag from " + s"table(information_schema.QUERY_HISTORY_BY_SESSION()) " + s"where query_text ilike '%$queryText%'", - session - ) + session) val result = statement.getResultSet val resArray = new ArrayBuffer[String]() @@ -290,8 +285,7 @@ trait SNTestBase extends FunSuite with BeforeAndAfterAll with SFTestUtils with S def withSessionParameters( params: Seq[(String, String)], currentSession: Session, - skipPreprod: Boolean = false - )(thunk: => Unit): Unit = { + skipPreprod: Boolean = false)(thunk: => Unit): Unit = { if (!(skipPreprod && isPreprodAccount)) { try { params.foreach { case (paramName, value) => @@ -313,11 +307,9 @@ trait SNTestBase extends FunSuite with BeforeAndAfterAll with SFTestUtils with S ("IGNORE_CLIENT_VESRION_IN_STRUCTURED_TYPES_RESPONSE", "true"), ("FORCE_ENABLE_STRUCTURED_TYPES_NATIVE_ARROW_FORMAT", "true"), ("ENABLE_STRUCTURED_TYPES_NATIVE_ARROW_FORMAT", "true"), - ("ENABLE_STRUCTURED_TYPES_IN_BINDS", "enable") - ), + ("ENABLE_STRUCTURED_TYPES_IN_BINDS", "enable")), currentSession, - skipPreprod = true - )(thunk) + skipPreprod = 
true)(thunk) // disable these tests on preprod daily tests until these parameters are enabled by default. } } diff --git a/src/test/scala/com/snowflake/snowpark/ServerConnectionSuite.scala b/src/test/scala/com/snowflake/snowpark/ServerConnectionSuite.scala index 7b9d7d64..4b69467a 100644 --- a/src/test/scala/com/snowflake/snowpark/ServerConnectionSuite.scala +++ b/src/test/scala/com/snowflake/snowpark/ServerConnectionSuite.scala @@ -50,8 +50,7 @@ class ServerConnectionSuite extends SNTestBase { } assert( ex1.errorCode.equals("0318") && - ex1.message.contains("The function call has been running for 19 seconds") - ) + ex1.message.contains("The function call has been running for 19 seconds")) } finally { session.cancelAll() statement.close() @@ -90,8 +89,7 @@ class ServerConnectionSuite extends SNTestBase { s"create or replace temporary table $tableName (c1 int, c2 string)", s"insert into $tableName values (1, 'abc'), (123, 'dfdffdfdf')", "select SYSTEM$WAIT(2)", - s"select max(c1) from $tableName" - ).map(Query(_)) + s"select max(c1) from $tableName").map(Query(_)) val attrs = Seq(Attribute("C1", LongType)) plan = new SnowflakePlan( queries, @@ -99,8 +97,7 @@ class ServerConnectionSuite extends SNTestBase { Seq(Query(s"drop table if exists $tableName", true)), session, None, - supportAsyncMode = true - ) + supportAsyncMode = true) asyncJob = session.conn.executeAsync(plan) iterator = session.conn.getAsyncResult(asyncJob.getQueryId(), Int.MaxValue, Some(plan))._1 rows = iterator.toSeq @@ -123,16 +120,14 @@ class ServerConnectionSuite extends SNTestBase { s"create or replace temporary table $tableName (c1 int, c2 string)", s"insert into $tableName values (1, 'abc'), (123, 'dfdffdfdf')", "select SYSTEM$WAIT(2)", - s"select to_number('not_a_number') as C1" - ).map(Query(_)) + s"select to_number('not_a_number') as C1").map(Query(_)) plan = new SnowflakePlan( queries2, schemaValueStatement(attrs), Seq(Query(s"drop table if exists $tableName", true)), session, None, - 
supportAsyncMode = true - ) + supportAsyncMode = true) asyncJob = session.conn.executeAsync(plan) val ex2 = intercept[SnowflakeSQLException] { session.conn.getAsyncResult(asyncJob.getQueryId(), Int.MaxValue, Some(plan))._1 @@ -148,9 +143,7 @@ class ServerConnectionSuite extends SNTestBase { val parameters2 = session.conn.getStatementParameters(true) assert( parameters2.size == 2 && parameters2.contains("QUERY_TAG") && parameters2.contains( - "SNOWPARK_SKIP_TXN_COMMIT_IN_DDL" - ) - ) + "SNOWPARK_SKIP_TXN_COMMIT_IN_DDL")) try { session.setQueryTag("test_tag_123") diff --git a/src/test/scala/com/snowflake/snowpark/SimplifierSuite.scala b/src/test/scala/com/snowflake/snowpark/SimplifierSuite.scala index 600aebaf..f0c90606 100644 --- a/src/test/scala/com/snowflake/snowpark/SimplifierSuite.scala +++ b/src/test/scala/com/snowflake/snowpark/SimplifierSuite.scala @@ -72,9 +72,7 @@ class SimplifierSuite extends TestData { query.contains( "WHERE ((\"A\" < 5 :: int) AND " + "((\"A\" > 1 :: int) AND ((\"B\" <> 10 :: int) AND " + - "(\"C\" = 100 :: int))))" - ) - ) + "(\"C\" = 100 :: int))))")) val result1 = df .filter((df("a") === 2) or (df("a") === 0)) @@ -83,11 +81,8 @@ class SimplifierSuite extends TestData { checkAnswer(result1, Seq(Row(2, 10, 100))) val query1 = result1.snowflakePlan.queries.last.sql assert( - query1.contains( - "WHERE (((\"A\" = 2 :: int) OR (\"A\" = 0 :: int))" + - " AND ((\"B\" = 10 :: int) OR (\"B\" = 20 :: int)))" - ) - ) + query1.contains("WHERE (((\"A\" = 2 :: int) OR (\"A\" = 0 :: int))" + + " AND ((\"B\" = 10 :: int) OR (\"B\" = 20 :: int)))")) assert(query1.split("WHERE").length == 2) } @@ -141,8 +136,7 @@ class SimplifierSuite extends TestData { checkAnswer( newDf, - Seq(Row(1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29)) - ) + Seq(Row(1, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29))) } test("withColumns 2") { diff --git 
a/src/test/scala/com/snowflake/snowpark/SnowflakePlanSuite.scala b/src/test/scala/com/snowflake/snowpark/SnowflakePlanSuite.scala index 1626fb24..a37ef963 100644 --- a/src/test/scala/com/snowflake/snowpark/SnowflakePlanSuite.scala +++ b/src/test/scala/com/snowflake/snowpark/SnowflakePlanSuite.scala @@ -43,8 +43,7 @@ class SnowflakePlanSuite extends SNTestBase { val queries = Seq( s"create or replace temporary table $tableName1 as select * from " + "values(1::INT, 'a'::STRING),(2::INT, 'b'::STRING) as T(A,B)", - s"select * from $tableName1" - ).map(Query(_)) + s"select * from $tableName1").map(Query(_)) val attrs = Seq(Attribute("A", IntegerType, nullable = true), Attribute("B", StringType, nullable = true)) @@ -55,8 +54,7 @@ class SnowflakePlanSuite extends SNTestBase { Seq.empty, session, None, - supportAsyncMode = true - ) + supportAsyncMode = true) val plan1 = session.plans.project(Seq("A"), plan, None) assert(plan1.attributes.length == 1) @@ -75,8 +73,7 @@ class SnowflakePlanSuite extends SNTestBase { val queries2 = Seq( s"create or replace temporary table $tableName2 as select * from " + "values(3::INT),(4::INT) as T(A)", - s"select * from $tableName2" - ).map(Query(_)) + s"select * from $tableName2").map(Query(_)) val attrs2 = Seq(Attribute("C", LongType)) val plan2 = @@ -86,8 +83,7 @@ class SnowflakePlanSuite extends SNTestBase { Seq.empty, session, None, - supportAsyncMode = true - ) + supportAsyncMode = true) val plan3 = session.plans.setOperator(plan1, plan2, "UNION ALL", None) assert(plan3.attributes.length == 1) @@ -102,8 +98,13 @@ class SnowflakePlanSuite extends SNTestBase { test("empty schema query") { assertThrows[SnowflakeSQLException]( - new SnowflakePlan(Seq.empty, "", Seq.empty, session, None, supportAsyncMode = true).attributes - ) + new SnowflakePlan( + Seq.empty, + "", + Seq.empty, + session, + None, + supportAsyncMode = true).attributes) } test("test SnowflakePlan.supportAsyncMode()") { @@ -184,8 +185,7 @@ class SnowflakePlanSuite extends 
SNTestBase { Seq.empty, session, None, - supportAsyncMode = true - ) + supportAsyncMode = true) val df = DataFrame(session, plan) var queryTag = Utils.getUserCodeMeta() diff --git a/src/test/scala/com/snowflake/snowpark/SnowparkSFConnectionHandlerSuite.scala b/src/test/scala/com/snowflake/snowpark/SnowparkSFConnectionHandlerSuite.scala index 8a14bd26..042a8146 100644 --- a/src/test/scala/com/snowflake/snowpark/SnowparkSFConnectionHandlerSuite.scala +++ b/src/test/scala/com/snowflake/snowpark/SnowparkSFConnectionHandlerSuite.scala @@ -13,8 +13,7 @@ class SnowparkSFConnectionHandlerSuite extends FunSuite { test("version negative") { val err = intercept[SnowparkClientException]( - SnowparkSFConnectionHandler.extractValidVersionNumber("0.1") - ) + SnowparkSFConnectionHandler.extractValidVersionNumber("0.1")) assert(err.message.contains("Invalid client version string")) } diff --git a/src/test/scala/com/snowflake/snowpark/StagedFileReaderSuite.scala b/src/test/scala/com/snowflake/snowpark/StagedFileReaderSuite.scala index 7f1850ee..ce045669 100644 --- a/src/test/scala/com/snowflake/snowpark/StagedFileReaderSuite.scala +++ b/src/test/scala/com/snowflake/snowpark/StagedFileReaderSuite.scala @@ -6,8 +6,7 @@ import com.snowflake.snowpark.types._ class StagedFileReaderSuite extends SNTestBase { private val userSchema: StructType = StructType( - Seq(StructField("a", IntegerType), StructField("b", StringType), StructField("c", DoubleType)) - ) + Seq(StructField("a", IntegerType), StructField("b", StringType), StructField("c", DoubleType))) test("File Format Type") { val fileReadOrCopyPlanBuilder = new StagedFileReader(session) @@ -40,8 +39,7 @@ class StagedFileReaderSuite extends SNTestBase { "Integer" -> java.lang.Integer.valueOf("1"), "true" -> "True", "false" -> "false", - "string" -> "string" - ) + "string" -> "string") val savedOptions = fileReadOrCopyPlanBuilder.options(configs).curOptions assert(savedOptions.size == 6) assert(savedOptions("BOOLEAN").equals("true")) @@ 
-67,8 +65,7 @@ class StagedFileReaderSuite extends SNTestBase { val crt = plan.queries.head.sql assert( TestUtils - .containIgnoreCaseAndWhiteSpaces(crt, s"create table test_table_name if not exists") - ) + .containIgnoreCaseAndWhiteSpaces(crt, s"create table test_table_name if not exists")) val copy = plan.queries.last.sql assert(TestUtils.containIgnoreCaseAndWhiteSpaces(copy, s"copy into test_table_name")) assert(TestUtils.containIgnoreCaseAndWhiteSpaces(copy, "skip_header = 10")) diff --git a/src/test/scala/com/snowflake/snowpark/TestData.scala b/src/test/scala/com/snowflake/snowpark/TestData.scala index e1a34fb2..3a9f1caf 100644 --- a/src/test/scala/com/snowflake/snowpark/TestData.scala +++ b/src/test/scala/com/snowflake/snowpark/TestData.scala @@ -8,8 +8,7 @@ trait TestData extends SNTestBase { lazy val testData2: DataFrame = session.createDataFrame( - Seq(Data2(1, 1), Data2(1, 2), Data2(2, 1), Data2(2, 2), Data2(3, 1), Data2(3, 2)) - ) + Seq(Data2(1, 1), Data2(1, 2), Data2(2, 1), Data2(2, 2), Data2(3, 1), Data2(3, 2))) lazy val testData3: DataFrame = session.createDataFrame(Seq(Data3(1, None), Data3(2, Some(2)))) @@ -22,8 +21,7 @@ trait TestData extends SNTestBase { LowerCaseData(1, "a") :: LowerCaseData(2, "b") :: LowerCaseData(3, "c") :: - LowerCaseData(4, "d") :: Nil - ) + LowerCaseData(4, "d") :: Nil) lazy val upperCaseData: DataFrame = session.createDataFrame( @@ -32,24 +30,21 @@ trait TestData extends SNTestBase { UpperCaseData(3, "C") :: UpperCaseData(4, "D") :: UpperCaseData(5, "E") :: - UpperCaseData(6, "F") :: Nil - ) + UpperCaseData(6, "F") :: Nil) lazy val nullInts: DataFrame = session.createDataFrame( NullInts(1) :: NullInts(2) :: NullInts(3) :: - NullInts(null) :: Nil - ) + NullInts(null) :: Nil) lazy val allNulls: DataFrame = session.createDataFrame( NullInts(null) :: NullInts(null) :: NullInts(null) :: - NullInts(null) :: Nil - ) + NullInts(null) :: Nil) lazy val nullData1: DataFrame = session.sql("select * from values(null),(2),(1),(3),(null) 
as T(a)") @@ -57,15 +52,13 @@ trait TestData extends SNTestBase { lazy val nullData2: DataFrame = session.sql( "select * from values(1,2,3),(null,2,3),(null,null,3),(null,null,null)," + - "(1,null,3),(1,null,null),(1,2,null) as T(a,b,c)" - ) + "(1,null,3),(1,null,null),(1,2,null) as T(a,b,c)") lazy val nullData3: DataFrame = session.sql( "select * from values(1.0, 1, true, 'a'),('NaN'::Double, 2, null, 'b')," + "(null, 3, false, null), (4.0, null, null, 'd'), (null, null, null, null), " + - "('NaN'::Double, null, null, null) as T(flo, int, boo, str)" - ) + "('NaN'::Double, null, null, null) as T(flo, int, boo, str)") lazy val integer1: DataFrame = session.sql("select * from values(1),(2),(3) as T(a)") @@ -78,8 +71,7 @@ trait TestData extends SNTestBase { lazy val double3: DataFrame = session.sql( "select * from values(1.0, 1),('NaN'::Double, 2),(null, 3)," + - " (4.0, null), (null, null), ('NaN'::Double, null) as T(a, b)" - ) + " (4.0, null), (null, null), ('NaN'::Double, null) as T(a, b)") lazy val double4: DataFrame = session.sql("select * from values(1.0, 1) as T(a, b)") @@ -96,8 +88,7 @@ trait TestData extends SNTestBase { lazy val approxNumbers2: DataFrame = session.sql( "select * from values(1, 1),(2, 1),(3, 3),(4, 3),(5, 3),(6, 3),(7, 3)," + - "(8, 5),(9, 5),(0, 5) as T(a, T)" - ) + "(8, 5),(9, 5),(0, 5) as T(a, T)") lazy val string1: DataFrame = session.sql("select * from values('test1', 'a'),('test2', 'b'),('test3', 'c') as T(a, b)") @@ -123,39 +114,33 @@ trait TestData extends SNTestBase { lazy val array1: DataFrame = session.sql( "select array_construct(a,b,c) as arr1, array_construct(d,e,f) as arr2 " + - "from values(1,2,3,3,4,5),(6,7,8,9,0,1) as T(a,b,c,d,e,f)" - ) + "from values(1,2,3,3,4,5),(6,7,8,9,0,1) as T(a,b,c,d,e,f)") lazy val array2: DataFrame = session.sql( "select array_construct(a,b,c) as arr1, d, e, f " + - "from values(1,2,3,2,'e1','[{a:1}]'),(6,7,8,1,'e2','[{a:1},{b:2}]') as T(a,b,c,d,e,f)" - ) + "from 
values(1,2,3,2,'e1','[{a:1}]'),(6,7,8,1,'e2','[{a:1},{b:2}]') as T(a,b,c,d,e,f)") lazy val array3: DataFrame = session.sql( "select array_construct(a,b,c) as arr1, d, e, f " + - "from values(1,2,3,1,2,','),(4,5,6,1,-1,', '),(6,7,8,0,2,';') as T(a,b,c,d,e,f)" - ) + "from values(1,2,3,1,2,','),(4,5,6,1,-1,', '),(6,7,8,0,2,';') as T(a,b,c,d,e,f)") lazy val object1: DataFrame = session.sql( "select key, to_variant(value) as value from values('age', 21),('zip', " + - "94401) as T(key,value)" - ) + "94401) as T(key,value)") lazy val object2: DataFrame = session.sql( "select object_construct(a,b,c,d,e,f) as obj, k, v, flag from values('age', 21, 'zip', " + "21021, 'name', 'Joe', 'age', 0, true),('age', 26, 'zip', 94021, 'name', 'Jay', 'key', " + - "0, false) as T(a,b,c,d,e,f,k,v,flag)" - ) + "0, false) as T(a,b,c,d,e,f,k,v,flag)") lazy val nullArray1: DataFrame = session.sql( "select array_construct(a,b,c) as arr1, array_construct(d,e,f) as arr2 " + - "from values(1,null,3,3,null,5),(6,null,8,9,null,1) as T(a,b,c,d,e,f)" - ) + "from values(1,null,3,3,null,5),(6,null,8,9,null,1) as T(a,b,c,d,e,f)") lazy val variant1: DataFrame = session.sql( @@ -171,8 +156,7 @@ trait TestData extends SNTestBase { " to_variant(to_timestamp_tz('2017-02-24 13:00:00.123 +01:00')) as timestamp_tz1, " + " to_variant(1.23::decimal(6, 3)) as decimal1, " + " to_variant(3.21::double) as double1, " + - " to_variant(15) as num1 " - ) + " to_variant(15) as num1 ") lazy val variant2: DataFrame = session.sql(""" @@ -194,27 +178,23 @@ trait TestData extends SNTestBase { |""".stripMargin) lazy val nullJson1: DataFrame = session.sql( - "select parse_json(column1) as v from values ('{\"a\": null}'), ('{\"a\": \"foo\"}'), (null)" - ) + "select parse_json(column1) as v from values ('{\"a\": null}'), ('{\"a\": \"foo\"}'), (null)") lazy val validJson1: DataFrame = session.sql( "select parse_json(column1) as v, column2 as k from values ('{\"a\": null}','a'), " + - "('{\"a\": \"foo\"}','a'), ('{\"a\": 
\"foo\"}','b'), (null,'a')" - ) + "('{\"a\": \"foo\"}','a'), ('{\"a\": \"foo\"}','b'), (null,'a')") lazy val invalidJson1: DataFrame = session.sql("select (column1) as v from values ('{\"a\": null'), ('{\"a: \"foo\"}'), ('{\"a:')") lazy val nullXML1: DataFrame = session.sql( "select (column1) as v from values ('foobar'), " + - "(''), (null), ('')" - ) + "(''), (null), ('')") lazy val validXML1: DataFrame = session.sql( "select parse_xml(a) as v, b as t2, c as t3, d as instance from values" + "('foobar','t2','t3',0),('','t2','t3',0)," + - "('foobar','t2','t3',1) as T(a,b,c,d)" - ) + "('foobar','t2','t3',1) as T(a,b,c,d)") lazy val invalidXML1: DataFrame = session.sql("select (column1) as v from values (''), (''), ('')") @@ -224,8 +204,7 @@ trait TestData extends SNTestBase { lazy val date2: DataFrame = session.sql( - "select * from values('2020-08-01'::Date, 'mo'),('2010-12-01'::Date, 'we') as T(a,b)" - ) + "select * from values('2020-08-01'::Date, 'mo'),('2010-12-01'::Date, 'we') as T(a,b)") lazy val date3: DataFrame = Seq((2020, 10, 28, 13, 35, 47, 1234567, "America/Los_Angeles")).toDF( @@ -236,8 +215,7 @@ trait TestData extends SNTestBase { "minute", "second", "nanosecond", - "timezone" - ) + "timezone") lazy val number1: DataFrame = session.createDataFrame( Seq( @@ -245,9 +223,7 @@ trait TestData extends SNTestBase { Number1(2, 10.0, 11.0), Number1(2, 20.0, 22.0), Number1(2, 25.0, 0), - Number1(2, 30.0, 35.0) - ) - ) + Number1(2, 30.0, 35.0))) lazy val decimalData: DataFrame = session.createDataFrame( @@ -256,8 +232,7 @@ trait TestData extends SNTestBase { DecimalData(2, 1) :: DecimalData(2, 2) :: DecimalData(3, 1) :: - DecimalData(3, 2) :: Nil - ) + DecimalData(3, 2) :: Nil) lazy val number2: DataFrame = session.createDataFrame(Seq(Number2(1, 2, 3), Number2(0, -1, 4), Number2(-5, 0, -9))) @@ -268,18 +243,20 @@ trait TestData extends SNTestBase { lazy val timestamp1: DataFrame = session.sql( "select * from values('2020-05-01 13:11:20.000' :: timestamp)," + - 
"('2020-08-21 01:30:05.000' :: timestamp) as T(a)" - ) + "('2020-08-21 01:30:05.000' :: timestamp) as T(a)") lazy val timestampNTZ: DataFrame = session.sql( "select * from values('2020-05-01 13:11:20.000' :: timestamp_ntz)," + - "('2020-08-21 01:30:05.000' :: timestamp_ntz) as T(a)" - ) + "('2020-08-21 01:30:05.000' :: timestamp_ntz) as T(a)") lazy val xyz: DataFrame = session.createDataFrame( - Seq(Number2(1, 2, 1), Number2(1, 2, 3), Number2(2, 1, 10), Number2(2, 2, 1), Number2(2, 2, 3)) - ) + Seq( + Number2(1, 2, 1), + Number2(1, 2, 3), + Number2(2, 1, 10), + Number2(2, 2, 1), + Number2(2, 2, 3))) lazy val long1: DataFrame = session.sql("select * from values(1561479557),(1565479557),(1161479557) as T(a)") @@ -291,8 +268,7 @@ trait TestData extends SNTestBase { CourseSales("Java", 2012, 20000) :: CourseSales("dotNET", 2012, 5000) :: CourseSales("dotNET", 2013, 48000) :: - CourseSales("Java", 2013, 30000) :: Nil - ) + CourseSales("Java", 2013, 30000) :: Nil) lazy val monthlySales: DataFrame = session.createDataFrame( Seq( @@ -311,9 +287,7 @@ trait TestData extends SNTestBase { MonthlySales(1, 8000, "APR"), MonthlySales(1, 10000, "APR"), MonthlySales(2, 800, "APR"), - MonthlySales(2, 4500, "APR") - ) - ) + MonthlySales(2, 4500, "APR"))) lazy val columnNameHasSpecialCharacter: DataFrame = { Seq((1, 2), (3, 4)).toDF("\"col %\"", "\"col *\"") @@ -326,8 +300,7 @@ trait TestData extends SNTestBase { (471, "Andrea Renee Nouveau", "RN", "Amateur Extra"), (101, "Lily Vine", "LVN", null), (102, "Larry Vancouver", "LVN", null), - (172, "Rhonda Nova", "RN", null) - ).toDF("id", "full_name", "medical_license", "radio_license") + (172, "Rhonda Nova", "RN", null)).toDF("id", "full_name", "medical_license", "radio_license") } diff --git a/src/test/scala/com/snowflake/snowpark/TestUtils.scala b/src/test/scala/com/snowflake/snowpark/TestUtils.scala index 427f28c9..46b22227 100644 --- a/src/test/scala/com/snowflake/snowpark/TestUtils.scala +++ 
b/src/test/scala/com/snowflake/snowpark/TestUtils.scala @@ -106,8 +106,7 @@ object TestUtils extends Logging { stageName: String, fileName: String, compress: Boolean, - session: Session - ): Unit = { + session: Session): Unit = { val input = getClass.getResourceAsStream(s"/$fileName") session.conn.uploadStream(stageName, null, input, fileName, compress) } @@ -116,8 +115,7 @@ object TestUtils extends Logging { def addDepsToClassPath( sess: Session, stageName: Option[String], - usePackages: Boolean = false - ): Unit = { + usePackages: Boolean = false): Unit = { // Initialize the lazy val so the defaultURIs are added to session sess.udf sess.udtf @@ -146,8 +144,7 @@ object TestUtils extends Logging { classOf[BeforeAndAfterAll], // scala test jar classOf[org.scalactic.TripleEquals], // scalactic jar classOf[io.opentelemetry.exporters.inmemory.InMemorySpanExporter], - classOf[io.opentelemetry.sdk.trace.export.SpanExporter] - ) + classOf[io.opentelemetry.sdk.trace.export.SpanExporter]) .flatMap(UDFClassPath.getPathForClass(_)) .foreach(path => { val file = new File(path) @@ -176,8 +173,7 @@ object TestUtils extends Logging { classOf[BeforeAndAfterAll], // scala test jar classOf[org.scalactic.TripleEquals], // scalactic jar classOf[io.opentelemetry.exporters.inmemory.InMemorySpanExporter], - classOf[io.opentelemetry.sdk.trace.export.SpanExporter] - ) + classOf[io.opentelemetry.sdk.trace.export.SpanExporter]) .flatMap(UDFClassPath.getPathForClass(_)) .foreach(path => { val file = new File(path) @@ -203,23 +199,17 @@ object TestUtils extends Logging { (0 until columnCount).foreach(index => { assert( quoteNameWithoutUpperCasing(resultMeta.getColumnLabel(index + 1)) == expectedSchema( - index - ).columnIdentifier.quotedName - ) + index).columnIdentifier.quotedName) assert( (resultMeta.isNullable(index + 1) != ResultSetMetaData.columnNoNulls) == expectedSchema( - index - ).nullable - ) + index).nullable) assert( ServerConnection.getDataType( resultMeta.getColumnType(index + 
1), resultMeta.getColumnTypeName(index + 1), resultMeta.getPrecision(index + 1), resultMeta.getScale(index + 1), - resultMeta.isSigned(index + 1) - ) == expectedSchema(index).dataType - ) + resultMeta.isSigned(index + 1)) == expectedSchema(index).dataType) }) statement.close() } @@ -237,8 +227,8 @@ object TestUtils extends Logging { def compare(obj1: Any, obj2: Any): Boolean = { val res = (obj1, obj2) match { case (null, null) => true - case (null, _) => false - case (_, null) => false + case (null, _) => false + case (_, null) => false case (a: Array[_], b: Array[_]) => a.length == b.length && a.zip(b).forall { case (l, r) => compare(l, r) } case (a: Map[_, _], b: Map[_, _]) => @@ -252,20 +242,20 @@ object TestUtils extends Logging { case (a: Row, b: Row) => compare(a.toSeq, b.toSeq) // Note this returns 0.0 and -0.0 as same - case (a: BigDecimal, b) => compare(a.bigDecimal, b) - case (a, b: BigDecimal) => compare(a, b.bigDecimal) - case (a: Float, b) => compare(a.toDouble, b) - case (a, b: Float) => compare(a, b.toDouble) + case (a: BigDecimal, b) => compare(a.bigDecimal, b) + case (a, b: BigDecimal) => compare(a, b.bigDecimal) + case (a: Float, b) => compare(a.toDouble, b) + case (a, b: Float) => compare(a, b.toDouble) case (a: Double, b: Double) if a.isNaN && b.isNaN => true - case (a: Double, b: Double) => (a - b).abs < 0.0001 - case (a: Double, b: java.math.BigDecimal) => (a - b.toString.toDouble).abs < 0.0001 - case (a: java.math.BigDecimal, b: Double) => (a.toString.toDouble - b).abs < 0.0001 + case (a: Double, b: Double) => (a - b).abs < 0.0001 + case (a: Double, b: java.math.BigDecimal) => (a - b.toString.toDouble).abs < 0.0001 + case (a: java.math.BigDecimal, b: Double) => (a.toString.toDouble - b).abs < 0.0001 case (a: java.math.BigDecimal, b: java.math.BigDecimal) => // BigDecimal(1.2) isn't equal to BigDecimal(1.200), so can't use function // equal to verify a.subtract(b).abs().compareTo(java.math.BigDecimal.valueOf(0.0001)) == -1 - case (a: Date, b: 
Date) => a.toString == b.toString - case (a: Time, b: Time) => a.toString == b.toString + case (a: Date, b: Date) => a.toString == b.toString + case (a: Time, b: Time) => a.toString == b.toString case (a: Timestamp, b: Timestamp) => a.toString == b.toString case (a: Geography, b: Geography) => // todo: improve Geography.equals method @@ -279,8 +269,7 @@ object TestUtils extends Logging { println( s"Find different elements: " + s"$obj1${if (obj1 != null) s":${obj1.getClass.getSimpleName}" else ""} != " + - s"$obj2${if (obj2 != null) s":${obj2.getClass.getSimpleName}" else ""}" - ) + s"$obj2${if (obj2 != null) s":${obj2.getClass.getSimpleName}" else ""}") // scalastyle:on println } res @@ -292,8 +281,7 @@ object TestUtils extends Logging { assert( compare(sorted, sortedExpected.toArray[Row]), s"${sorted.map(_.toString).mkString("[", ", ", "]")} != " + - s"${sortedExpected.map(_.toString).mkString("[", ", ", "]")}" - ) + s"${sortedExpected.map(_.toString).mkString("[", ", ", "]")}") } def checkResult(result: Array[Row], expected: java.util.List[Row], sort: Boolean): Unit = @@ -335,8 +323,7 @@ object TestUtils extends Logging { packageName: String, className: String, pathPrefix: String, - jarFileName: String - ): String = { + jarFileName: String): String = { val dummyCode = s""" | package $packageName; @@ -383,9 +370,9 @@ object TestUtils extends Logging { def tryToLoadFipsProvider(): Unit = { val isFipsTest = { System.getProperty("FIPS_TEST") match { - case null => false + case null => false case value if value.toLowerCase() == "true" => true - case _ => false + case _ => false } } diff --git a/src/test/scala/com/snowflake/snowpark/UDFClasspathSuite.scala b/src/test/scala/com/snowflake/snowpark/UDFClasspathSuite.scala index 60cd162f..3c3f388d 100644 --- a/src/test/scala/com/snowflake/snowpark/UDFClasspathSuite.scala +++ b/src/test/scala/com/snowflake/snowpark/UDFClasspathSuite.scala @@ -76,8 +76,7 @@ class UDFClasspathSuite extends SNTestBase { randomStage, "", 
jarFileName, - Map.empty - ) + Map.empty) mockSession.addDependency("@" + randomStage + "/" + jarFileName) val func = "func_" + Random.nextInt().abs udfR.registerUDF(Some(func), _toUdf((a: Int) => a + a), None) @@ -90,8 +89,7 @@ class UDFClasspathSuite extends SNTestBase { test( "Test that snowpark jar is automatically added" + - " if there is classNotFound error in first attempt" - ) { + " if there is classNotFound error in first attempt") { val newSession = Session.builder.configFile(defaultProfile).create TestUtils.addDepsToClassPath(newSession) val udfR = spy(new UDXRegistrationHandler(newSession)) diff --git a/src/test/scala/com/snowflake/snowpark/UDFInternalSuite.scala b/src/test/scala/com/snowflake/snowpark/UDFInternalSuite.scala index 5d721627..da272207 100644 --- a/src/test/scala/com/snowflake/snowpark/UDFInternalSuite.scala +++ b/src/test/scala/com/snowflake/snowpark/UDFInternalSuite.scala @@ -40,8 +40,7 @@ class UDFInternalSuite extends TestData { mockSession .createDataFrame(Seq(1, 2)) .select(doubleUDF(col("value"))), - Seq(Row(2), Row(4)) - ) + Seq(Row(2), Row(4))) if (mockSession.isVersionSupportedByServerPackages) { verify(mockSession, times(0)).addDependency(path) } else { @@ -98,9 +97,7 @@ class UDFInternalSuite extends TestData { case e: SQLException => assert( e.getMessage.contains( - s"SQL compilation error: Package '${Utils.clientPackageName}' is not supported" - ) - ) + s"SQL compilation error: Package '${Utils.clientPackageName}' is not supported")) case _ => fail("Unexpected error from server") } val path = UDFClassPath.getPathForClass(classOf[com.snowflake.snowpark.Session]).get @@ -125,22 +122,19 @@ class UDFInternalSuite extends TestData { session.udf.registerTemporary(quotedTempFuncName, udf) assert( - session.sql(s"ls ${session.getSessionStage}/$quotedTempFuncName/").collect().length == 0 - ) + session.sql(s"ls ${session.getSessionStage}/$quotedTempFuncName/").collect().length == 0) assert( session .sql(s"ls 
${session.getSessionStage}/${Utils.getUDFUploadPrefix(quotedTempFuncName)}/") .collect() - .length == 2 - ) + .length == 2) session.udf.registerPermanent(quotedPermFuncName, udf, stageName) assert(session.sql(s"ls @$stageName/$quotedPermFuncName/").collect().length == 0) assert( session .sql(s"ls @$stageName/${Utils.getUDFUploadPrefix(quotedPermFuncName)}/") .collect() - .length == 2 - ) + .length == 2) } finally { session.runQuery(s"drop function if exists $tempFuncName(INT)") session.runQuery(s"drop function if exists $permFuncName(INT)") @@ -165,8 +159,7 @@ class UDFInternalSuite extends TestData { }, "snowpark_use_scoped_temp_objects", "true", - skipIfParamNotExist = true - ) + skipIfParamNotExist = true) } test("register UDF should not upload duplicated dependencies", JavaStoredProcExclude) { diff --git a/src/test/scala/com/snowflake/snowpark/UDFRegistrationSuite.scala b/src/test/scala/com/snowflake/snowpark/UDFRegistrationSuite.scala index 3b7d6927..1c0984b6 100644 --- a/src/test/scala/com/snowflake/snowpark/UDFRegistrationSuite.scala +++ b/src/test/scala/com/snowflake/snowpark/UDFRegistrationSuite.scala @@ -33,8 +33,7 @@ class UDFRegistrationSuite extends SNTestBase with FileUtils { tempStage, stagePrefix, jarFileName, - funcBytesMap - ) + funcBytesMap) val stageFile = "@" + tempStage + "/" + stagePrefix + "/" + jarFileName // Download file from stage session.runQuery(s"get $stageFile file://${TestUtils.tempDirWithEscape}") @@ -42,8 +41,7 @@ class UDFRegistrationSuite extends SNTestBase with FileUtils { // Check that classes in directories in UDFClasspath are included assert( classesInJar.contains("com/snowflake/snowpark/Session.class") || session.packageNames - .contains(clientPackageName) - ) + .contains(clientPackageName)) // Check that classes in jars in UDFClasspath are NOT included assert(!classesInJar.contains("scala/Function1.class")) // Check that function class is included @@ -62,8 +60,7 @@ class UDFRegistrationSuite extends SNTestBase with 
FileUtils { tempStage, stagePrefix, jarFileName, - funcBytesMap - ) + funcBytesMap) } assert(ex1.isInstanceOf[NoSuchFileException]) @@ -75,21 +72,18 @@ class UDFRegistrationSuite extends SNTestBase with FileUtils { "not_exist_stage_name", stagePrefix, jarFileName, - funcBytesMap - ) + funcBytesMap) } assert( ex2.getMessage.contains("Stage") && - ex2.getMessage.contains("does not exist or not authorized.") - ) + ex2.getMessage.contains("does not exist or not authorized.")) } // Dynamic Compile scala code private def generateDynamicClass( packageName: String, className: String, - inMemory: Boolean - ): Class[_] = { + inMemory: Boolean): Class[_] = { // Generate a temp file for the scala code. val classContent = s"package $packageName\n class $className {\n class InnerClass {}\n}\nclass OuterClass {}\n" diff --git a/src/test/scala/com/snowflake/snowpark/UDTFInternalSuite.scala b/src/test/scala/com/snowflake/snowpark/UDTFInternalSuite.scala index bbea73d7..2598955c 100644 --- a/src/test/scala/com/snowflake/snowpark/UDTFInternalSuite.scala +++ b/src/test/scala/com/snowflake/snowpark/UDTFInternalSuite.scala @@ -25,14 +25,12 @@ class UDTFInternalSuite extends SNTestBase { assert( udtfHandler .generateUDTFClassSignature(udtf0, udtf0.inputColumns) - .equals("com.snowflake.snowpark.udtf.UDTF0") - ) + .equals("com.snowflake.snowpark.udtf.UDTF0")) val udtf00 = new TestUDTF0() assert( udtfHandler .generateUDTFClassSignature(udtf00, udtf00.inputColumns) - .equals("com.snowflake.snowpark.udtf.UDTF0") - ) + .equals("com.snowflake.snowpark.udtf.UDTF0")) } test("Unit test for UDTF1") { @@ -51,14 +49,12 @@ class UDTFInternalSuite extends SNTestBase { assert( udtfHandler .generateUDTFClassSignature(udtf1, udtf1.inputColumns) - .equals("com.snowflake.snowpark.udtf.UDTF1") - ) + .equals("com.snowflake.snowpark.udtf.UDTF1")) val udtf11 = new TestUDTF1() assert( udtfHandler .generateUDTFClassSignature(udtf11, udtf11.inputColumns) - .equals("com.snowflake.snowpark.udtf.UDTF1") - ) + 
.equals("com.snowflake.snowpark.udtf.UDTF1")) } @@ -79,14 +75,12 @@ class UDTFInternalSuite extends SNTestBase { assert( udtfHandler .generateUDTFClassSignature(udtf2, udtf2.inputColumns) - .equals("com.snowflake.snowpark.udtf.UDTF2") - ) + .equals("com.snowflake.snowpark.udtf.UDTF2")) val udtf22 = new TestUDTF2() assert( udtfHandler .generateUDTFClassSignature(udtf22, udtf22.inputColumns) - .equals("com.snowflake.snowpark.udtf.UDTF2") - ) + .equals("com.snowflake.snowpark.udtf.UDTF2")) } test("negative test: Unsupported type is used") { diff --git a/src/test/scala/com/snowflake/snowpark/UtilsSuite.scala b/src/test/scala/com/snowflake/snowpark/UtilsSuite.scala index 70e36829..b552a632 100644 --- a/src/test/scala/com/snowflake/snowpark/UtilsSuite.scala +++ b/src/test/scala/com/snowflake/snowpark/UtilsSuite.scala @@ -34,8 +34,7 @@ class UtilsSuite extends SNTestBase { "snowpark-0.3.0.jar", "snowpark-SNAPSHOT-0.4.0.jar", "SNOWPARK-0.2.1.jar", - "snowpark-0.3.0.jar.gz" - ).foreach(jarName => { + "snowpark-0.3.0.jar.gz").foreach(jarName => { assert(Utils.isSnowparkJar(jarName)) }) Seq("random.jar", "snow-0.3.0.jar", "snowpark", "snowpark.tar.gz").foreach(jarName => { @@ -124,8 +123,7 @@ class UtilsSuite extends SNTestBase { assert(TypeToSchemaConverter.inferSchema[BigDecimal]().head.dataType == DecimalType(34, 6)) assert( TypeToSchemaConverter.inferSchema[JavaBigDecimal]().head.dataType == - DecimalType(34, 6) - ) + DecimalType(34, 6)) assert(TypeToSchemaConverter.inferSchema[Date]().head.dataType == DateType) assert(TypeToSchemaConverter.inferSchema[Timestamp]().head.dataType == TimestampType) assert(TypeToSchemaConverter.inferSchema[Time]().head.dataType == TimeType) @@ -133,9 +131,7 @@ class UtilsSuite extends SNTestBase { assert( TypeToSchemaConverter.inferSchema[Map[String, Boolean]]().head.dataType == MapType( StringType, - BooleanType - ) - ) + BooleanType)) assert(TypeToSchemaConverter.inferSchema[Variant]().head.dataType == VariantType) 
assert(TypeToSchemaConverter.inferSchema[Geography]().head.dataType == GeographyType) assert(TypeToSchemaConverter.inferSchema[Geometry]().head.dataType == GeometryType) @@ -152,8 +148,7 @@ class UtilsSuite extends SNTestBase { | |--_4: Geography (nullable = true) | |--_5: Map (nullable = true) | |--_6: Geometry (nullable = true) - |""".stripMargin - ) + |""".stripMargin) // case class assert( @@ -165,8 +160,7 @@ class UtilsSuite extends SNTestBase { | |--DOUBLE: Double (nullable = false) | |--VARIANT: Variant (nullable = true) | |--ARRAY: Array (nullable = true) - |""".stripMargin - ) + |""".stripMargin) } case class Table1(int: Int, double: Double, variant: Variant, array: Array[String]) @@ -183,8 +177,7 @@ class UtilsSuite extends SNTestBase { | |--LONG: Long (nullable = true) | |--FLOAT: Float (nullable = true) | |--DOUBLE: Double (nullable = true) - |""".stripMargin - ) + |""".stripMargin) } case class Table2( @@ -194,8 +187,7 @@ class UtilsSuite extends SNTestBase { int: Integer, long: java.lang.Long, float: java.lang.Float, - double: java.lang.Double - ) + double: java.lang.Double) test("Non-nullable types") { TypeToSchemaConverter @@ -214,28 +206,25 @@ class UtilsSuite extends SNTestBase { test("Nullable types") { TypeToSchemaConverter - .inferSchema[ - ( - Option[Int], - JavaBoolean, - JavaByte, - JavaShort, - JavaInteger, - JavaLong, - JavaFloat, - JavaDouble, - Array[Boolean], - Map[String, Double], - JavaBigDecimal, - BigDecimal, - Variant, - Geography, - Date, - Time, - Timestamp, - Geometry - ) - ]() + .inferSchema[( + Option[Int], + JavaBoolean, + JavaByte, + JavaShort, + JavaInteger, + JavaLong, + JavaFloat, + JavaDouble, + Array[Boolean], + Map[String, Double], + JavaBigDecimal, + BigDecimal, + Variant, + Geography, + Date, + Time, + Timestamp, + Geometry)]() .treeString(0) == """root | |--_1: Integer (nullable = true) @@ -380,12 +369,10 @@ class UtilsSuite extends SNTestBase { assert(Utils.parseStageFileLocation("@stage/file") == ("@stage", "", 
"file")) assert( Utils.parseStageFileLocation("@\"st\\age\"/path/file") - == ("@\"st\\age\"", "path", "file") - ) + == ("@\"st\\age\"", "path", "file")) assert( Utils.parseStageFileLocation("@\"\\db\".\"\\Schema\".\"\\stage\"/path/file") - == ("@\"\\db\".\"\\Schema\".\"\\stage\"", "path", "file") - ) + == ("@\"\\db\".\"\\Schema\".\"\\stage\"", "path", "file")) assert(Utils.parseStageFileLocation("@stage/////file") == ("@stage", "///", "file")) } @@ -420,8 +407,7 @@ class UtilsSuite extends SNTestBase { "\"\"\"na.me\"\"\"", "\"n\"\"a..m\"\"e\"", "\"schema\"\"\".\"n\"\"a..m\"\"e\"", - "\"\"\"db\".\"schema\"\"\".\"n\"\"a..m\"\"e\"" - ) + "\"\"\"db\".\"schema\"\"\".\"n\"\"a..m\"\"e\"") test("test Utils.validateObjectName()") { validIdentifiers.foreach { name => @@ -490,8 +476,7 @@ class UtilsSuite extends SNTestBase { "a\"\"b\".c.t", ".\"name..\"", "..\"name\"", - "\"\".\"name\"" - ) + "\"\".\"name\"") names.foreach { name => // println(s"negative test: $name") @@ -515,8 +500,7 @@ class UtilsSuite extends SNTestBase { ("com/snowflake/snowpark/", "com/snowflake/snowpark/"), ("com/snowflake/snowpark", "com/snowflake/snowpark"), ("d:", "d:"), - ("d:\\dir", "d:/dir") - ) + ("d:\\dir", "d:/dir")) testItems.foreach { item => assert(Utils.convertWindowsPathToLinux(item._1).equals(item._2)) @@ -629,20 +613,16 @@ class UtilsSuite extends SNTestBase { test("java utils javaSaveModeToScala") { assert( JavaUtils.javaSaveModeToScala(com.snowflake.snowpark_java.SaveMode.Append) - == SaveMode.Append - ) + == SaveMode.Append) assert( JavaUtils.javaSaveModeToScala(com.snowflake.snowpark_java.SaveMode.Ignore) - == SaveMode.Ignore - ) + == SaveMode.Ignore) assert( JavaUtils.javaSaveModeToScala(com.snowflake.snowpark_java.SaveMode.Overwrite) - == SaveMode.Overwrite - ) + == SaveMode.Overwrite) assert( JavaUtils.javaSaveModeToScala(com.snowflake.snowpark_java.SaveMode.ErrorIfExists) - == SaveMode.ErrorIfExists - ) + == SaveMode.ErrorIfExists) } test("isValidateJavaIdentifier()") { diff 
--git a/src/test/scala/com/snowflake/snowpark_test/AsyncJobSuite.scala b/src/test/scala/com/snowflake/snowpark_test/AsyncJobSuite.scala index 3d3f02a3..b7d70507 100644 --- a/src/test/scala/com/snowflake/snowpark_test/AsyncJobSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/AsyncJobSuite.scala @@ -18,8 +18,7 @@ class AsyncJobSuite extends TestData with BeforeAndAfterEach { val tableName1: String = randomName() val tableName2: String = randomName() private val userSchema: StructType = StructType( - Seq(StructField("a", IntegerType), StructField("b", StringType), StructField("c", DoubleType)) - ) + Seq(StructField("a", IntegerType), StructField("b", StringType), StructField("c", DoubleType))) // session to verify permanent udf lazy private val newSession = Session.builder.configFile(defaultProfile).create @@ -128,11 +127,8 @@ class AsyncJobSuite extends TestData with BeforeAndAfterEach { } assert(ex1.errorCode.equals("0318")) assert( - ex1.getMessage.matches( - ".*The query with the ID .* is still running and " + - "has the current status RUNNING. The function call has been running for 2 seconds,.*" - ) - ) + ex1.getMessage.matches(".*The query with the ID .* is still running and " + + "has the current status RUNNING. 
The function call has been running for 2 seconds,.*")) // getIterator() raises exception val ex2 = intercept[SnowparkClientException] { @@ -353,8 +349,7 @@ class AsyncJobSuite extends TestData with BeforeAndAfterEach { // This function is copied from DataFrameReader.testReadFile def testReadFile(testName: String, testTags: Tag*)( - thunk: (() => DataFrameReader) => Unit - ): Unit = { + thunk: (() => DataFrameReader) => Unit): Unit = { // test select test(testName + " - SELECT", testTags: _*) { thunk(() => session.read) @@ -384,9 +379,7 @@ class AsyncJobSuite extends TestData with BeforeAndAfterEach { Seq( StructField("a", IntegerType), StructField("b", IntegerType), - StructField("c", IntegerType) - ) - ) + StructField("c", IntegerType))) val df2 = reader().schema(incorrectSchema).csv(testFileOnStage) assertThrows[SnowflakeSQLException](df2.async.collect().getResult()) }) @@ -405,22 +398,18 @@ class AsyncJobSuite extends TestData with BeforeAndAfterEach { val df = session.createDataFrame(Seq(1, 2)).toDF(Seq("a")) checkResult( df.select(callUDF(tempFuncName, df("a"))).async.collect().getResult(), - Seq(Row(2), Row(3)) - ) + Seq(Row(2), Row(3))) checkResult( df.select(callUDF(permFuncName, df("a"))).async.collect().getResult(), - Seq(Row(2), Row(3)) - ) + Seq(Row(2), Row(3))) // another session val df2 = newSession.createDataFrame(Seq(1, 2)).toDF(Seq("a")) checkResult( df2.select(callUDF(permFuncName, df("a"))).async.collect().getResult(), - Seq(Row(2), Row(3)) - ) + Seq(Row(2), Row(3))) assertThrows[SnowflakeSQLException]( - df2.select(callUDF(tempFuncName, df("a"))).async.collect().getResult() - ) + df2.select(callUDF(tempFuncName, df("a"))).async.collect().getResult()) } finally { runQuery(s"drop function if exists $tempFuncName(INT)", session) @@ -485,8 +474,7 @@ class AsyncJobSuite extends TestData with BeforeAndAfterEach { val rows = asyncJob1.getRows() assert( rows.length == 1 && rows.head.length == 1 && - rows.head.getString(0).contains(s"Table $tableName 
successfully created") - ) + rows.head.getString(0).contains(s"Table $tableName successfully created")) // all session.table checkAnswer(session.table(tableName), Seq(Row(1), Row(2), Row(3))) checkAnswer(session.table(Seq(db, sc, tableName)), Seq(Row(1), Row(2), Row(3))) @@ -529,8 +517,7 @@ class AsyncJobSuite extends TestData with BeforeAndAfterEach { checkAnswer( session.table(tableName), Seq(Row(1, "one", 1.2), Row(2, "two", 2.2), Row(1, "one", 1.2), Row(2, "two", 2.2)), - false - ) + false) } finally { dropTable(tableName) } @@ -574,8 +561,7 @@ class AsyncJobSuite extends TestData with BeforeAndAfterEach { checkAnswer( t2, Seq(Row(0, "A"), Row(0, "B"), Row(0, "C"), Row(4, "D"), Row(5, "E"), Row(6, "F")), - sort = false - ) + sort = false) testData2.write.mode(SaveMode.Overwrite).saveAsTable(tableName1) upperCaseData.write.mode(SaveMode.Overwrite).saveAsTable(tableName2) @@ -584,8 +570,7 @@ class AsyncJobSuite extends TestData with BeforeAndAfterEach { checkAnswer( t2, Seq(Row(0, "A"), Row(0, "B"), Row(0, "C"), Row(4, "D"), Row(5, "E"), Row(6, "F")), - sort = false - ) + sort = false) upperCaseData.write.mode(SaveMode.Overwrite).saveAsTable(tableName2) import session.implicits._ @@ -594,8 +579,7 @@ class AsyncJobSuite extends TestData with BeforeAndAfterEach { assert(asyncJob3.getResult() == UpdateResult(4, 0)) checkAnswer( t2, - Seq(Row(0, "A"), Row(0, "B"), Row(0, "D"), Row(0, "E"), Row(3, "C"), Row(6, "F")) - ) + Seq(Row(0, "A"), Row(0, "B"), Row(0, "D"), Row(0, "E"), Row(3, "C"), Row(6, "F"))) } // Copy UpdatableSuite.test("delete rows from table") and @@ -666,8 +650,7 @@ class AsyncJobSuite extends TestData with BeforeAndAfterEach { assert(mergeBuilder.async.collect().getResult() == MergeResult(4, 2, 2)) checkAnswer( target, - Seq(Row(1, 21), Row(3, 3), Row(4, 4), Row(5, 25), Row(7, null), Row(null, 26)) - ) + Seq(Row(1, 21), Row(3, 3), Row(4, 4), Row(5, 25), Row(7, null), Row(null, 26))) } // Async executes Merge and get result in another session. 
@@ -704,8 +687,7 @@ class AsyncJobSuite extends TestData with BeforeAndAfterEach { checkResult(mergeResultRows, Seq(Row(4, 2, 2))) checkAnswer( newSession.table(tableName), - Seq(Row(1, 21), Row(3, 3), Row(4, 4), Row(5, 25), Row(7, null), Row(null, 26)) - ) + Seq(Row(1, 21), Row(3, 3), Row(4, 4), Row(5, 25), Row(7, null), Row(null, 26))) } private def runCSvTestAsync( @@ -714,8 +696,7 @@ class AsyncJobSuite extends TestData with BeforeAndAfterEach { options: Map[String, Any], expectedWriteResult: Array[Row], outputFileExtension: String, - saveMode: Option[SaveMode] = None - ) = { + saveMode: Option[SaveMode] = None) = { // Execute COPY INTO location and check result val writer = df.write.options(options) saveMode.foreach(writer.mode) @@ -733,8 +714,7 @@ class AsyncJobSuite extends TestData with BeforeAndAfterEach { options: Map[String, Any], expectedWriteResult: Array[Row], outputFileExtension: String, - saveMode: Option[SaveMode] = None - ) = { + saveMode: Option[SaveMode] = None) = { // Execute COPY INTO location and check result val writer = df.write.options(options) saveMode.foreach(writer.mode) @@ -752,8 +732,7 @@ class AsyncJobSuite extends TestData with BeforeAndAfterEach { options: Map[String, Any], expectedNumberOfRow: Int, outputFileExtension: String, - saveMode: Option[SaveMode] = None - ) = { + saveMode: Option[SaveMode] = None) = { // Execute COPY INTO location and check result val writer = df.write.options(options) saveMode.foreach(writer.mode) @@ -776,9 +755,7 @@ class AsyncJobSuite extends TestData with BeforeAndAfterEach { Seq( StructField("c1", IntegerType), StructField("c2", DoubleType), - StructField("c3", StringType) - ) - ) + StructField("c3", StringType))) val df = session.table(tableName) val path = s"@$targetStageName/p_${Random.nextInt().abs}" @@ -786,8 +763,7 @@ class AsyncJobSuite extends TestData with BeforeAndAfterEach { runCSvTestAsync(df, path, Map.empty, Array(Row(3, 32, 46)), ".csv.gz") checkAnswer( 
session.read.schema(schema).csv(path), - Seq(Row(1, 1.1, "one"), Row(2, 2.2, "two"), Row(null, null, null)) - ) + Seq(Row(1, 1.1, "one"), Row(2, 2.2, "two"), Row(null, null, null))) // by default, the mode is ErrorIfExist val ex = intercept[SnowflakeSQLException] { @@ -799,8 +775,7 @@ class AsyncJobSuite extends TestData with BeforeAndAfterEach { runCSvTestAsync(df, path, Map.empty, Array(Row(3, 32, 46)), ".csv.gz", Some(SaveMode.Overwrite)) checkAnswer( session.read.schema(schema).csv(path), - Seq(Row(1, 1.1, "one"), Row(2, 2.2, "two"), Row(null, null, null)) - ) + Seq(Row(1, 1.1, "one"), Row(2, 2.2, "two"), Row(null, null, null))) // test some file format options and values session.sql(s"remove $path").collect() @@ -808,13 +783,11 @@ class AsyncJobSuite extends TestData with BeforeAndAfterEach { "FIELD_DELIMITER" -> "'aa'", "RECORD_DELIMITER" -> "bbbb", "COMPRESSION" -> "NONE", - "FILE_EXTENSION" -> "mycsv" - ) + "FILE_EXTENSION" -> "mycsv") runCSvTestAsync(df, path, options1, Array(Row(3, 47, 47)), ".mycsv") checkAnswer( session.read.schema(schema).options(options1).csv(path), - Seq(Row(1, 1.1, "one"), Row(2, 2.2, "two"), Row(null, null, null)) - ) + Seq(Row(1, 1.1, "one"), Row(2, 2.2, "two"), Row(null, null, null))) // Test file format name only val fileFormatName = randomTableName() @@ -822,8 +795,7 @@ class AsyncJobSuite extends TestData with BeforeAndAfterEach { .sql( s"CREATE OR REPLACE TEMPORARY FILE FORMAT $fileFormatName " + s"TYPE = CSV FIELD_DELIMITER = 'aa' RECORD_DELIMITER = 'bbbb' " + - s"COMPRESSION = 'NONE' FILE_EXTENSION = 'mycsv'" - ) + s"COMPRESSION = 'NONE' FILE_EXTENSION = 'mycsv'") .collect() runCSvTestAsync( df, @@ -831,20 +803,17 @@ class AsyncJobSuite extends TestData with BeforeAndAfterEach { Map("FORMAT_NAME" -> fileFormatName), Array(Row(3, 47, 47)), ".mycsv", - Some(SaveMode.Overwrite) - ) + Some(SaveMode.Overwrite)) checkAnswer( session.read.schema(schema).options(options1).csv(path), - Seq(Row(1, 1.1, "one"), Row(2, 2.2, "two"), 
Row(null, null, null)) - ) + Seq(Row(1, 1.1, "one"), Row(2, 2.2, "two"), Row(null, null, null))) // Test file format name and some extra format options val fileFormatName2 = randomTableName() session .sql( s"CREATE OR REPLACE TEMPORARY FILE FORMAT $fileFormatName2 " + - s"TYPE = CSV FIELD_DELIMITER = 'aa' RECORD_DELIMITER = 'bbbb'" - ) + s"TYPE = CSV FIELD_DELIMITER = 'aa' RECORD_DELIMITER = 'bbbb'") .collect() val formatNameAndOptions = Map("FORMAT_NAME" -> fileFormatName2, "COMPRESSION" -> "NONE", "FILE_EXTENSION" -> "mycsv") @@ -854,12 +823,10 @@ class AsyncJobSuite extends TestData with BeforeAndAfterEach { formatNameAndOptions, Array(Row(3, 47, 47)), ".mycsv", - Some(SaveMode.Overwrite) - ) + Some(SaveMode.Overwrite)) checkAnswer( session.read.schema(schema).options(options1).csv(path), - Seq(Row(1, 1.1, "one"), Row(2, 2.2, "two"), Row(null, null, null)) - ) + Seq(Row(1, 1.1, "one"), Row(2, 2.2, "two"), Row(null, null, null))) } // Copy DataFrameWriterSuiter.test("write JSON files: file format options"") and @@ -874,8 +841,7 @@ class AsyncJobSuite extends TestData with BeforeAndAfterEach { runJsonTestAsync(df, path, Map.empty, Array(Row(2, 20, 40)), ".json.gz") checkAnswer( session.read.json(path), - Seq(Row("[\n 1,\n \"one\"\n]"), Row("[\n 2,\n \"two\"\n]")) - ) + Seq(Row("[\n 1,\n \"one\"\n]"), Row("[\n 2,\n \"two\"\n]"))) // write one column and overwrite val df2 = session.table(tableName).select(to_variant(col("c2"))) @@ -885,8 +851,7 @@ class AsyncJobSuite extends TestData with BeforeAndAfterEach { Map.empty, Array(Row(2, 12, 32)), ".json.gz", - Some(SaveMode.Overwrite) - ) + Some(SaveMode.Overwrite)) checkAnswer(session.read.json(path), Seq(Row("\"one\""), Row("\"two\""))) // write with format_name @@ -899,8 +864,7 @@ class AsyncJobSuite extends TestData with BeforeAndAfterEach { Map("FORMAT_NAME" -> formatName), Array(Row(2, 4, 24)), ".json.gz", - Some(SaveMode.Overwrite) - ) + Some(SaveMode.Overwrite)) session.read.json(path).show() 
checkAnswer(session.read.json(path), Seq(Row("1"), Row("2"))) @@ -913,8 +877,7 @@ class AsyncJobSuite extends TestData with BeforeAndAfterEach { Map("FORMAT_NAME" -> formatName, "FILE_EXTENSION" -> "myjson.json", "COMPRESSION" -> "NONE"), Array(Row(2, 4, 4)), ".myjson.json", - Some(SaveMode.Overwrite) - ) + Some(SaveMode.Overwrite)) session.read.json(path).show() checkAnswer(session.read.json(path), Seq(Row("1"), Row("2"))) } @@ -933,9 +896,7 @@ class AsyncJobSuite extends TestData with BeforeAndAfterEach { session.read.parquet(path), Seq( Row("{\n \"_COL_0\": 1,\n \"_COL_1\": \"one\"\n}"), - Row("{\n \"_COL_0\": 2,\n \"_COL_1\": \"two\"\n}") - ) - ) + Row("{\n \"_COL_0\": 2,\n \"_COL_1\": \"two\"\n}"))) // write with overwrite runParquetTest(df, path, Map.empty, 2, ".snappy.parquet", Some(SaveMode.Overwrite)) @@ -943,9 +904,7 @@ class AsyncJobSuite extends TestData with BeforeAndAfterEach { session.read.parquet(path), Seq( Row("{\n \"_COL_0\": 1,\n \"_COL_1\": \"one\"\n}"), - Row("{\n \"_COL_0\": 2,\n \"_COL_1\": \"two\"\n}") - ) - ) + Row("{\n \"_COL_0\": 2,\n \"_COL_1\": \"two\"\n}"))) // write with format_name val formatName = randomTableName() @@ -958,15 +917,12 @@ class AsyncJobSuite extends TestData with BeforeAndAfterEach { Map("FORMAT_NAME" -> formatName), 2, ".snappy.parquet", - Some(SaveMode.Overwrite) - ) + Some(SaveMode.Overwrite)) checkAnswer( session.read.parquet(path), Seq( Row("{\n \"_COL_0\": 1,\n \"_COL_1\": \"one\"\n}"), - Row("{\n \"_COL_0\": 2,\n \"_COL_1\": \"two\"\n}") - ) - ) + Row("{\n \"_COL_0\": 2,\n \"_COL_1\": \"two\"\n}"))) // write with format_name format and some extra option session.sql(s"rm $path").collect() @@ -977,15 +933,12 @@ class AsyncJobSuite extends TestData with BeforeAndAfterEach { Map("FORMAT_NAME" -> formatName, "COMPRESSION" -> "LZO"), 2, ".lzo.parquet", - Some(SaveMode.Overwrite) - ) + Some(SaveMode.Overwrite)) checkAnswer( session.read.parquet(path), Seq( Row("{\n \"_COL_0\": 1,\n \"_COL_1\": \"one\"\n}"), - 
Row("{\n \"_COL_0\": 2,\n \"_COL_1\": \"two\"\n}") - ) - ) + Row("{\n \"_COL_0\": 2,\n \"_COL_1\": \"two\"\n}"))) } // Copy DataFrameWriterSuiter.test("Catch COPY INTO LOCATION output schema change") and diff --git a/src/test/scala/com/snowflake/snowpark_test/ColumnSuite.scala b/src/test/scala/com/snowflake/snowpark_test/ColumnSuite.scala index 8d6b28fa..edc93edf 100644 --- a/src/test/scala/com/snowflake/snowpark_test/ColumnSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/ColumnSuite.scala @@ -47,8 +47,7 @@ class ColumnSuite extends TestData { test("unary operators") { assert(testData1.select(-testData1("NUM")).collect() sameElements Array[Row](Row(-1), Row(-2))) assert( - testData1.select(!testData1("BOOL")).collect() sameElements Array[Row](Row(false), Row(true)) - ) + testData1.select(!testData1("BOOL")).collect() sameElements Array[Row](Row(false), Row(true))) } test("alias") { @@ -61,67 +60,45 @@ class ColumnSuite extends TestData { test("equal and not equal") { assert( testData1.where(testData1("BOOL") === true).collect() sameElements Array[Row]( - Row(1, true, "a") - ) - ) + Row(1, true, "a"))) assert( testData1.where(testData1("BOOL") equal_to lit(true)).collect() sameElements Array[Row]( - Row(1, true, "a") - ) - ) + Row(1, true, "a"))) assert( testData1.where(testData1("BOOL") =!= true).collect() sameElements Array[Row]( - Row(2, false, "b") - ) - ) + Row(2, false, "b"))) assert( testData1.where(testData1("BOOL") not_equal lit(true)).collect() sameElements Array[Row]( - Row(2, false, "b") - ) - ) + Row(2, false, "b"))) } test("gt and lt") { assert( - testData1.where(testData1("NUM") > 1).collect() sameElements Array[Row](Row(2, false, "b")) - ) + testData1.where(testData1("NUM") > 1).collect() sameElements Array[Row](Row(2, false, "b"))) assert( testData1.where(testData1("NUM") gt lit(1)).collect() sameElements Array[Row]( - Row(2, false, "b") - ) - ) + Row(2, false, "b"))) assert( - testData1.where(testData1("NUM") < 2).collect() sameElements 
Array[Row](Row(1, true, "a")) - ) + testData1.where(testData1("NUM") < 2).collect() sameElements Array[Row](Row(1, true, "a"))) assert( testData1.where(testData1("NUM") lt lit(2)).collect() sameElements Array[Row]( - Row(1, true, "a") - ) - ) + Row(1, true, "a"))) } test("leq and geq") { assert( - testData1.where(testData1("NUM") >= 2).collect() sameElements Array[Row](Row(2, false, "b")) - ) + testData1.where(testData1("NUM") >= 2).collect() sameElements Array[Row](Row(2, false, "b"))) assert( testData1.where(testData1("NUM") geq lit(2)).collect() sameElements Array[Row]( - Row(2, false, "b") - ) - ) + Row(2, false, "b"))) assert( - testData1.where(testData1("NUM") <= 1).collect() sameElements Array[Row](Row(1, true, "a")) - ) + testData1.where(testData1("NUM") <= 1).collect() sameElements Array[Row](Row(1, true, "a"))) assert( testData1.where(testData1("NUM") leq lit(1)).collect() sameElements Array[Row]( - Row(1, true, "a") - ) - ) + Row(1, true, "a"))) assert( testData1.where(testData1("NUM").between(lit(0), lit(1))).collect() sameElements Array[Row]( - Row(1, true, "a") - ) - ) + Row(1, true, "a"))) } @@ -130,13 +107,11 @@ class ColumnSuite extends TestData { assert( df.select(df("A") <=> df("B")).collect() sameElements - Array[Row](Row(false), Row(true), Row(true)) - ) + Array[Row](Row(false), Row(true), Row(true))) assert( df.select(df("A").equal_null(df("B"))).collect() sameElements - Array[Row](Row(false), Row(true), Row(true)) - ) + Array[Row](Row(false), Row(true), Row(true))) } test("NaN and Null") { @@ -148,26 +123,21 @@ class ColumnSuite extends TestData { assert( df.where(df("A").is_not_null).collect() sameElements Array[Row]( Row(1.1, 1), - Row(Double.NaN, 3) - ) - ) + Row(Double.NaN, 3))) } test("&& ||") { val df = session.sql( - "select * from values(true,true),(true,false),(false, true), (false, false) as T(a, b)" - ) + "select * from values(true,true),(true,false),(false, true), (false, false) as T(a, b)") assert(df.where(df("A") && 
df("B")).collect() sameElements Array[Row](Row(true, true))) assert(df.where(df("A") and df("B")).collect() sameElements Array[Row](Row(true, true))) assert( df.where(df("A") || df("B")) - .collect() sameElements Array[Row](Row(true, true), Row(true, false), Row(false, true)) - ) + .collect() sameElements Array[Row](Row(true, true), Row(true, false), Row(false, true))) assert( df.where(df("A") or df("B")) - .collect() sameElements Array[Row](Row(true, true), Row(true, false), Row(false, true)) - ) + .collect() sameElements Array[Row](Row(true, true), Row(true, false), Row(false, true))) } test("+ - * / %") { @@ -197,43 +167,35 @@ class ColumnSuite extends TestData { val sc = testData1.select(testData1("NUM").cast(StringType)).schema assert( sc == StructType( - Array(StructField("\"CAST (\"\"NUM\"\" AS STRING)\"", StringType, nullable = false)) - ) - ) + Array(StructField("\"CAST (\"\"NUM\"\" AS STRING)\"", StringType, nullable = false)))) } test("order") { assert( nullData1 .sort(nullData1("A").asc) - .collect() sameElements Array[Row](Row(null), Row(null), Row(1), Row(2), Row(3)) - ) + .collect() sameElements Array[Row](Row(null), Row(null), Row(1), Row(2), Row(3))) assert( nullData1 .sort(nullData1("A").asc_nulls_first) - .collect() sameElements Array[Row](Row(null), Row(null), Row(1), Row(2), Row(3)) - ) + .collect() sameElements Array[Row](Row(null), Row(null), Row(1), Row(2), Row(3))) assert( nullData1 .sort(nullData1("A").asc_nulls_last) - .collect() sameElements Array[Row](Row(1), Row(2), Row(3), Row(null), Row(null)) - ) + .collect() sameElements Array[Row](Row(1), Row(2), Row(3), Row(null), Row(null))) assert( nullData1 .sort(nullData1("A").desc) - .collect() sameElements Array[Row](Row(3), Row(2), Row(1), Row(null), Row(null)) - ) + .collect() sameElements Array[Row](Row(3), Row(2), Row(1), Row(null), Row(null))) assert( nullData1 .sort(nullData1("A").desc_nulls_last) - .collect() sameElements Array[Row](Row(3), Row(2), Row(1), Row(null), Row(null)) - ) + 
.collect() sameElements Array[Row](Row(3), Row(2), Row(1), Row(null), Row(null))) assert( nullData1 .sort(nullData1("A").desc_nulls_first) - .collect() sameElements Array[Row](Row(null), Row(null), Row(3), Row(2), Row(1)) - ) + .collect() sameElements Array[Row](Row(null), Row(null), Row(3), Row(2), Row(1))) } test("bitwise operator") { @@ -305,12 +267,11 @@ class ColumnSuite extends TestData { assert( df.drop(Seq.empty[String]).schema.fields.map(_.name).toSeq.sorted equals Seq( """"One"""", - """ONE""" - ) - ) + """ONE""")) assert( - df.drop(""""one"""").schema.fields.map(_.name).toSeq.sorted equals Seq(""""One"""", """ONE""") - ) + df.drop(""""one"""").schema.fields.map(_.name).toSeq.sorted equals Seq( + """"One"""", + """ONE""")) val ex = intercept[SnowparkClientException] { df.drop("ONE", """"One"""").collect() @@ -326,15 +287,11 @@ class ColumnSuite extends TestData { assert( df.drop(col(""""one"""")).schema.fields.map(_.name).toSeq.sorted equals Seq( """"One"""", - """ONE""" - ) - ) + """ONE""")) assert( df.drop(Seq.empty[Column]).schema.fields.map(_.name).toSeq.sorted equals Seq( """"One"""", - """ONE""" - ) - ) + """ONE""")) var ex = intercept[SnowparkClientException] { df.drop(df("ONE"), col(""""One"""")).collect() @@ -504,16 +461,14 @@ class ColumnSuite extends TestData { checkAnswer( string4.where(col("A").like(lit("%p%"))), Seq(Row("apple"), Row("peach")), - sort = false - ) + sort = false) } test("subfield") { checkAnswer( nullJson1.select(col("v")("a")), Seq(Row("null"), Row("\"foo\""), Row(null)), - sort = false - ) + sort = false) checkAnswer(array2.select(col("arr1")(0)), Seq(Row("1"), Row("6")), sort = false) @@ -523,13 +478,11 @@ class ColumnSuite extends TestData { checkAnswer( variant2.select(col("src")("vehicle")(0)("make")), Seq(Row("\"Honda\"")), - sort = false - ) + sort = false) checkAnswer( variant2.select(col("SRC")("vehicle")(0)("make")), Seq(Row("\"Honda\"")), - sort = false - ) + sort = false) 
checkAnswer(variant2.select(col("src")("vehicle")(0)("MAKE")), Seq(Row(null)), sort = false) checkAnswer(variant2.select(col("src")("VEHICLE")(0)("make")), Seq(Row(null)), sort = false) @@ -537,8 +490,7 @@ class ColumnSuite extends TestData { checkAnswer( variant2.select(col("src")("date with '' and .")), Seq(Row("\"2017-04-28\"")), - sort = false - ) + sort = false) // path is not accepted checkAnswer(variant2.select(col("src")("salesperson.id")), Seq(Row(null)), sort = false) @@ -564,11 +516,9 @@ class ColumnSuite extends TestData { .when(col("a").is_null, lit(5)) .when(col("a") === 1, lit(6)) .otherwise(lit(7)) - .as("a") - ), + .as("a")), Seq(Row(5), Row(7), Row(6), Row(7), Row(5)), - sort = false - ) + sort = false) checkAnswer( nullData1.select( @@ -576,11 +526,9 @@ class ColumnSuite extends TestData { .when(col("a").is_null, lit(5)) .when(col("a") === 1, lit(6)) .`else`(lit(7)) - .as("a") - ), + .as("a")), Seq(Row(5), Row(7), Row(6), Row(7), Row(5)), - sort = false - ) + sort = false) // empty otherwise checkAnswer( @@ -588,11 +536,9 @@ class ColumnSuite extends TestData { functions .when(col("a").is_null, lit(5)) .when(col("a") === 1, lit(6)) - .as("a") - ), + .as("a")), Seq(Row(5), Row(null), Row(6), Row(null), Row(5)), - sort = false - ) + sort = false) // wrong type, snowflake sql exception assertThrows[SnowflakeSQLException]( @@ -601,10 +547,8 @@ class ColumnSuite extends TestData { functions .when(col("a").is_null, lit("a")) .when(col("a") === 1, lit(6)) - .as("a") - ) - .collect() - ) + .as("a")) + .collect()) } test("lit contains '") { @@ -622,8 +566,7 @@ class ColumnSuite extends TestData { |------------------------- ||'616263' |'' |NULL | |------------------------- - |""".stripMargin - ) + |""".stripMargin) } test("In Expression 1: IN with constant value list") { @@ -640,8 +583,7 @@ class ColumnSuite extends TestData { ||1 |a |1 |1 | ||2 |b |2 |2 | |------------------------- - |""".stripMargin - ) + |""".stripMargin) // filter with NOT val df2 = 
df.filter(!col("a").in(Seq(lit(1), lit(2)))) @@ -652,8 +594,7 @@ class ColumnSuite extends TestData { |------------------------- ||3 |b |33 |33 | |------------------------- - |""".stripMargin - ) + |""".stripMargin) // select without NOT val df3 = df.select(col("a").in(Seq(1, 2)).as("in_result")) @@ -666,8 +607,7 @@ class ColumnSuite extends TestData { ||true | ||false | |--------------- - |""".stripMargin - ) + |""".stripMargin) // select with NOT val df4 = df.select(!col("a").in(Seq(1, 2)).as("in_result")) @@ -680,8 +620,7 @@ class ColumnSuite extends TestData { ||false | ||true | |--------------- - |""".stripMargin - ) + |""".stripMargin) } test("In Expression 2: In with sub query") { @@ -708,8 +647,7 @@ class ColumnSuite extends TestData { ||false | ||false | |--------------- - |""".stripMargin - ) + |""".stripMargin) // select with NOT val df4 = df.select(!df("a").in(df0.filter(col("a") < 2)).as("in_result")) @@ -722,8 +660,7 @@ class ColumnSuite extends TestData { ||true | ||true | |--------------- - |""".stripMargin - ) + |""".stripMargin) } test("In Expression 3: IN with all types") { @@ -745,9 +682,7 @@ class ColumnSuite extends TestData { StructField("boolean", BooleanType), StructField("binary", BinaryType), StructField("timestamp", TimestampType), - StructField("date", DateType) - ) - ) + StructField("date", DateType))) val timestamp: Long = 1606179541282L val largeData = new ArrayBuffer[Row]() @@ -766,9 +701,7 @@ class ColumnSuite extends TestData { true, Array(1.toByte, 2.toByte), new Timestamp(timestamp - 100), - new Date(timestamp - 100) - ) - ) + new Date(timestamp - 100))) } val df = session.createDataFrame(largeData, schema) @@ -786,17 +719,13 @@ class ColumnSuite extends TestData { col("decimal").in( Seq( new java.math.BigDecimal(1.2).setScale(3, RoundingMode.HALF_UP), - new java.math.BigDecimal(1.3).setScale(3, RoundingMode.HALF_UP) - ) - ) - ) + new java.math.BigDecimal(1.3).setScale(3, RoundingMode.HALF_UP)))) val df3 = df2.filter( 
col("boolean").in(Seq(true, false)) && col("binary").in(Seq(Array(1.toByte, 2.toByte), Array(2.toByte, 3.toByte))) && col("timestamp") .in(Seq(new Timestamp(timestamp - 100), new Timestamp(timestamp - 200))) && - col("date").in(Seq(new Date(timestamp - 100), new Date(timestamp - 200))) - ) + col("date").in(Seq(new Date(timestamp - 100), new Date(timestamp - 200)))) df3.show() // scalastyle:off @@ -808,8 +737,7 @@ class ColumnSuite extends TestData { ||1 |a |1 |2 |3 |4 |1.1 |1.2 |1.200 |true |'0102' |2020-11-23 16:59:01.182 |2020-11-23 | ||2 |a |1 |2 |3 |4 |1.1 |1.2 |1.200 |true |'0102' |2020-11-23 16:59:01.182 |2020-11-23 | |------------------------------------------------------------------------------------------------------------------------------------------------------ - |""".stripMargin - ) + |""".stripMargin) // scalastyle:on } finally { TimeZone.setDefault(oldTimeZone) @@ -868,8 +796,7 @@ class ColumnSuite extends TestData { df.select( functions .in(Seq(col("a"), col("c")), Seq(Seq(1, 1), Seq(2, 2), Seq(3, 3))) - .as("in_result") - ) + .as("in_result")) assert( getShowString(df3, 10, 50) == """--------------- @@ -879,15 +806,13 @@ class ColumnSuite extends TestData { ||true | ||false | |--------------- - |""".stripMargin - ) + |""".stripMargin) // select with NOT val df4 = df.select( (!functions.in(Seq(col("a"), col("c")), Seq(Seq(1, 1), Seq(2, 2), Seq(3, 3)))) - .as("in_result") - ) + .as("in_result")) assert( getShowString(df4, 10, 50) == """--------------- @@ -897,8 +822,7 @@ class ColumnSuite extends TestData { ||false | ||true | |--------------- - |""".stripMargin - ) + |""".stripMargin) } test("In Expression 7: multiple columns with sub query") { @@ -925,8 +849,7 @@ class ColumnSuite extends TestData { ||true | ||false | |--------------- - |""".stripMargin - ) + |""".stripMargin) // select with NOT val df4 = df.select((!functions.in(Seq(col("a"), col("b")), df0)).as("in_result")) @@ -939,8 +862,7 @@ class ColumnSuite extends TestData { ||false | ||true 
| |--------------- - |""".stripMargin - ) + |""".stripMargin) } // Below cases are supported by snowflake SQL, but they are confusing, diff --git a/src/test/scala/com/snowflake/snowpark_test/ComplexDataFrameSuite.scala b/src/test/scala/com/snowflake/snowpark_test/ComplexDataFrameSuite.scala index 4e82bff9..1922a1f6 100644 --- a/src/test/scala/com/snowflake/snowpark_test/ComplexDataFrameSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/ComplexDataFrameSuite.scala @@ -8,8 +8,7 @@ trait ComplexDataFrameSuite extends SNTestBase { val tableName: String = randomName() val tmpStageName: String = randomStageName() private val userSchema: StructType = StructType( - Seq(StructField("a", IntegerType), StructField("b", StringType), StructField("c", DoubleType)) - ) + Seq(StructField("a", IntegerType), StructField("b", StringType), StructField("c", DoubleType))) override def beforeAll: Unit = { super.beforeAll() @@ -37,13 +36,11 @@ trait ComplexDataFrameSuite extends SNTestBase { checkAnswer( df1.join(df2, "a").union(df2.filter($"a" < 2).union(df2.filter($"a" >= 2))), - Seq(Row(1, "test1"), Row(2, "test2")) - ) + Seq(Row(1, "test1"), Row(2, "test2"))) checkAnswer( df1.join(df2, "a").unionAll(df2.filter($"a" < 2).unionAll(df2.filter($"a" >= 2))), - Seq(Row(1, "test1"), Row(1, "test1"), Row(2, "test2"), Row(2, "test2")) - ) + Seq(Row(1, "test1"), Row(1, "test1"), Row(2, "test2"), Row(2, "test2"))) } test("Combination of multiple operators with filters") { @@ -56,8 +53,7 @@ trait ComplexDataFrameSuite extends SNTestBase { checkAnswer( df1.filter($"a" < 6).join(df2, "a").union(df2.filter($"a" > 5)), - (1 to 10).map(i => Row(i, s"test$i")) - ) + (1 to 10).map(i => Row(i, s"test$i"))) } test("join on top of unions") { @@ -68,8 +64,7 @@ trait ComplexDataFrameSuite extends SNTestBase { checkAnswer( df1.union(df2).join(df3.union(df4), "a").sort(functions.col("a")), (1 to 10).map(i => Row(i, s"test$i")), - false - ) + false) } test("Combination of multiple data sources") { @@ 
-78,13 +73,11 @@ trait ComplexDataFrameSuite extends SNTestBase { val df2 = session.table(tableName) checkAnswer( df2.join(df1, df1("a") === df2("num")), - Seq(Row(1, 1, "one", 1.2), Row(2, 2, "two", 2.2)) - ) + Seq(Row(1, 1, "one", 1.2), Row(2, 2, "two", 2.2))) checkAnswer( df2.filter($"num" === 1).join(df1.select("a", "b"), df1("a") === df2("num")), - Seq(Row(1, 1, "one")) - ) + Seq(Row(1, 1, "one"))) } } diff --git a/src/test/scala/com/snowflake/snowpark_test/CopyableDataFrameSuite.scala b/src/test/scala/com/snowflake/snowpark_test/CopyableDataFrameSuite.scala index 34ed6e53..460d20ae 100644 --- a/src/test/scala/com/snowflake/snowpark_test/CopyableDataFrameSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/CopyableDataFrameSuite.scala @@ -13,8 +13,7 @@ class CopyableDataFrameSuite extends SNTestBase { val testTableName: String = randomName() private val userSchema: StructType = StructType( - Seq(StructField("a", IntegerType), StructField("b", StringType), StructField("c", DoubleType)) - ) + Seq(StructField("a", IntegerType), StructField("b", StringType), StructField("c", DoubleType))) override def beforeAll(): Unit = { super.beforeAll() @@ -66,8 +65,7 @@ class CopyableDataFrameSuite extends SNTestBase { .copyInto(testTableName) checkAnswer( session.table(testTableName), - Seq(Row(1, "one", 1.2), Row(2, "two", 2.2), Row(1, "one", 1.2), Row(2, "two", 2.2)) - ) + Seq(Row(1, "one", 1.2), Row(2, "two", 2.2), Row(1, "one", 1.2), Row(2, "two", 2.2))) } test("copy csv test: create target table automatically if not exists") { @@ -84,8 +82,7 @@ class CopyableDataFrameSuite extends SNTestBase { | |--A: Long (nullable = true) | |--B: String (nullable = true) | |--C: Double (nullable = true) - |""".stripMargin - ) + |""".stripMargin) // run COPY again, the loaded files will be skipped by default df.copyInto(testTableName) @@ -107,8 +104,7 @@ class CopyableDataFrameSuite extends SNTestBase { | |--C1: Long (nullable = true) | |--C2: String (nullable = true) | |--C3: 
Double (nullable = true) - |""".stripMargin - ) + |""".stripMargin) // run COPY again, the loaded files will be skipped by default df.copyInto(testTableName) @@ -118,8 +114,7 @@ class CopyableDataFrameSuite extends SNTestBase { df.write.saveAsTable(testTableName) checkAnswer( session.table(testTableName), - Seq(Row(1, "one", 1.2), Row(2, "two", 2.2), Row(1, "one", 1.2), Row(2, "two", 2.2)) - ) + Seq(Row(1, "one", 1.2), Row(2, "two", 2.2), Row(1, "one", 1.2), Row(2, "two", 2.2))) // Write data with saveAsTable() again, loaded file are NOT skipped. df.write.saveAsTable(testTableName) @@ -131,9 +126,7 @@ class CopyableDataFrameSuite extends SNTestBase { Row(1, "one", 1.2), Row(2, "two", 2.2), Row(1, "one", 1.2), - Row(2, "two", 2.2) - ) - ) + Row(2, "two", 2.2))) } test("copy csv test: copy transformation") { @@ -150,8 +143,7 @@ class CopyableDataFrameSuite extends SNTestBase { | |--C1: String (nullable = true) | |--C2: String (nullable = true) | |--C3: String (nullable = true) - |""".stripMargin - ) + |""".stripMargin) checkAnswer(session.table(testTableName), Seq(Row("1", "one", "1.2"), Row("2", "two", "2.2"))) // Copy data in order of $3, $2, $1 with FORCE = TRUE @@ -162,16 +154,13 @@ class CopyableDataFrameSuite extends SNTestBase { Row("1", "one", "1.2"), Row("2", "two", "2.2"), Row("1.2", "one", "1"), - Row("2.2", "two", "2") - ) - ) + Row("2.2", "two", "2"))) // Copy data in order of $2, $3, $1 with FORCE = TRUE and skip_header = 1 df.copyInto( testTableName, Seq(col("$2"), col("$3"), col("$1")), - Map("FORCE" -> "TRUE", "skip_header" -> 1) - ) + Map("FORCE" -> "TRUE", "skip_header" -> 1)) checkAnswer( session.table(testTableName), Seq( @@ -179,9 +168,7 @@ class CopyableDataFrameSuite extends SNTestBase { Row("2", "two", "2.2"), Row("1.2", "one", "1"), Row("2.2", "two", "2"), - Row("two", "2.2", "2") - ) - ) + Row("two", "2.2", "2"))) } test("copy csv test: negative test") { @@ -196,9 +183,7 @@ class CopyableDataFrameSuite extends SNTestBase { assert( 
ex1.errorCode.equals("0122") && ex1.message.contains( s"Cannot create the target table $testTableName because Snowpark cannot determine" + - " the column names to use. You should create the table before calling copyInto()." - ) - ) + " the column names to use. You should create the table before calling copyInto().")) // case 2: copyInto transformation doesn't match table schema. createTable(testTableName, "c1 String") @@ -231,8 +216,7 @@ class CopyableDataFrameSuite extends SNTestBase { | |--C1: String (nullable = true) | |--C2: String (nullable = true) | |--C3: String (nullable = true) - |""".stripMargin - ) + |""".stripMargin) checkAnswer(session.table(testTableName), Seq(Row("1", "one", null), Row("2", "two", null))) // Copy data in order of $3, $2 to column c3 and c2 with FORCE = TRUE @@ -243,17 +227,14 @@ class CopyableDataFrameSuite extends SNTestBase { Row("1", "one", null), Row("2", "two", null), Row(null, "one", "1.2"), - Row(null, "two", "2.2") - ) - ) + Row(null, "two", "2.2"))) // Copy data $1 to column c3 with FORCE = TRUE and skip_header = 1 df.copyInto( testTableName, Seq("c3"), Seq(col("$1")), - Map("FORCE" -> "TRUE", "skip_header" -> 1) - ) + Map("FORCE" -> "TRUE", "skip_header" -> 1)) checkAnswer( session.table(testTableName), Seq( @@ -261,9 +242,7 @@ class CopyableDataFrameSuite extends SNTestBase { Row("2", "two", null), Row(null, "one", "1.2"), Row(null, "two", "2.2"), - Row(null, null, "2") - ) - ) + Row(null, null, "2"))) } test("copy csv test: copy into column names without transformation") { @@ -281,12 +260,10 @@ class CopyableDataFrameSuite extends SNTestBase { | |--C2: String (nullable = true) | |--C3: String (nullable = true) | |--C4: String (nullable = true) - |""".stripMargin - ) + |""".stripMargin) checkAnswer( session.table(testTableName), - Seq(Row("1", "one", "1.2", null), Row("2", "two", "2.2", null)) - ) + Seq(Row("1", "one", "1.2", null), Row("2", "two", "2.2", null))) // case 2: select more columns from csv than it have // 
There is only 3 columns in the schema of the csv file. @@ -297,9 +274,7 @@ class CopyableDataFrameSuite extends SNTestBase { Row("1", "one", "1.2", null), Row("2", "two", "2.2", null), Row("1", "one", "1.2", null), - Row("2", "two", "2.2", null) - ) - ) + Row("2", "two", "2.2", null))) } test("copy json test: write with column names") { @@ -312,16 +287,14 @@ class CopyableDataFrameSuite extends SNTestBase { testTableName, Seq("c1", "c2"), Seq(sqlExpr("$1:color"), sqlExpr("$1:fruit")), - Map.empty - ) + Map.empty) assert( TestUtils.treeString(session.table(testTableName).schema, 0) == s"""root | |--C1: String (nullable = true) | |--C2: Variant (nullable = true) | |--C3: String (nullable = true) - |""".stripMargin - ) + |""".stripMargin) checkAnswer(session.table(testTableName), Seq(Row("Red", "\"Apple\"", null))) } @@ -337,14 +310,13 @@ class CopyableDataFrameSuite extends SNTestBase { | |--C1: String (nullable = true) | |--C2: Variant (nullable = true) | |--C3: String (nullable = true) - |""".stripMargin - ) + |""".stripMargin) val ex = intercept[SnowflakeSQLException] { df.copyInto(testTableName, Seq("c1", "c2"), Seq.empty, Map.empty) } assert( - ex.getMessage.contains("JSON file format can produce one and only one column of type variant") - ) + ex.getMessage.contains( + "JSON file format can produce one and only one column of type variant")) } test("copy csv test: negative test with column names") { @@ -358,21 +330,15 @@ class CopyableDataFrameSuite extends SNTestBase { df.copyInto(testTableName, Seq("c1"), Seq(col("$1"), col("$2")), Map.empty) } assert( - ex2.getMessage.contains( - "Number of column names provided to copy " + - "into does not match the number of transformations" - ) - ) + ex2.getMessage.contains("Number of column names provided to copy " + + "into does not match the number of transformations")) // table has 3 column, transformation has 2 columns, column name has 3 val ex3 = intercept[SnowparkClientException] { df.copyInto(testTableName, 
Seq("c1", "c2", "c3"), Seq(col("$1"), col("$2")), Map.empty) } assert( - ex3.getMessage.contains( - "Number of column names provided to copy " + - "into does not match the number of transformations" - ) - ) + ex3.getMessage.contains("Number of column names provided to copy " + + "into does not match the number of transformations")) // case 2: column names contains unknown columns // table has 3 column, transformation has 4 columns, column name has 4 @@ -381,8 +347,7 @@ class CopyableDataFrameSuite extends SNTestBase { testTableName, Seq("c1", "c2", "c3", "c4"), Seq(col("$1"), col("$2"), col("$3"), col("$4")), - Map.empty - ) + Map.empty) } assert(ex4.getMessage.contains("invalid identifier 'C4'")) } @@ -409,19 +374,16 @@ class CopyableDataFrameSuite extends SNTestBase { TestUtils.treeString(session.table(testTableName).schema, 0) == s"""root | |--C1: Variant (nullable = true) - |""".stripMargin - ) + |""".stripMargin) checkAnswer( session.table(testTableName), - Seq(Row("{\n \"color\": \"Red\",\n \"fruit\": \"Apple\",\n \"size\": \"Large\"\n}")) - ) + Seq(Row("{\n \"color\": \"Red\",\n \"fruit\": \"Apple\",\n \"size\": \"Large\"\n}"))) // copy again: loaded file is skipped. df.copyInto(testTableName, Seq(col("$1").as("B"))) checkAnswer( session.table(testTableName), - Seq(Row("{\n \"color\": \"Red\",\n \"fruit\": \"Apple\",\n \"size\": \"Large\"\n}")) - ) + Seq(Row("{\n \"color\": \"Red\",\n \"fruit\": \"Apple\",\n \"size\": \"Large\"\n}"))) // copy again with FORCE = true. 
df.copyInto(testTableName, Seq(col("$1").as("B")), Map("FORCE" -> true)) @@ -429,9 +391,7 @@ class CopyableDataFrameSuite extends SNTestBase { session.table(testTableName), Seq( Row("{\n \"color\": \"Red\",\n \"fruit\": \"Apple\",\n \"size\": \"Large\"\n}"), - Row("{\n \"color\": \"Red\",\n \"fruit\": \"Apple\",\n \"size\": \"Large\"\n}") - ) - ) + Row("{\n \"color\": \"Red\",\n \"fruit\": \"Apple\",\n \"size\": \"Large\"\n}"))) } test("copy json test: write with transformation") { @@ -445,17 +405,14 @@ class CopyableDataFrameSuite extends SNTestBase { Seq( sqlExpr("$1:color").as("color"), sqlExpr("$1:fruit").as("fruit"), - sqlExpr("$1:size").as("size") - ) - ) + sqlExpr("$1:size").as("size"))) assert( TestUtils.treeString(session.table(testTableName).schema, 0) == s"""root | |--C1: String (nullable = true) | |--C2: Variant (nullable = true) | |--C3: String (nullable = true) - |""".stripMargin - ) + |""".stripMargin) checkAnswer(session.table(testTableName), Seq(Row("Red", "\"Apple\"", "Large"))) // copy again with existed table and FORCE = true. 
@@ -465,22 +422,18 @@ class CopyableDataFrameSuite extends SNTestBase { Seq( sqlExpr("$1:size").as("size"), sqlExpr("$1:fruit").as("fruit"), - sqlExpr("$1:color").as("color") - ), - Map("FORCE" -> true) - ) + sqlExpr("$1:color").as("color")), + Map("FORCE" -> true)) assert( TestUtils.treeString(session.table(testTableName).schema, 0) == s"""root | |--C1: String (nullable = true) | |--C2: Variant (nullable = true) | |--C3: String (nullable = true) - |""".stripMargin - ) + |""".stripMargin) checkAnswer( session.table(testTableName), - Seq(Row("Red", "\"Apple\"", "Large"), Row("Large", "\"Apple\"", "Red")) - ) + Seq(Row("Red", "\"Apple\"", "Large"), Row("Large", "\"Apple\"", "Red"))) } test("copy json test: negative test") { @@ -495,9 +448,7 @@ class CopyableDataFrameSuite extends SNTestBase { assert( ex1.errorCode.equals("0122") && ex1.message.contains( s"Cannot create the target table $testTableName because Snowpark cannot determine" + - " the column names to use. You should create the table before calling copyInto()." - ) - ) + " the column names to use. You should create the table before calling copyInto().")) // case 2: copy with transformation when target table doesn't exist dropTable(testTableName)(session) @@ -507,9 +458,7 @@ class CopyableDataFrameSuite extends SNTestBase { assert( ex2.errorCode.equals("0122") && ex2.message.contains( s"Cannot create the target table $testTableName because Snowpark cannot determine" + - " the column names to use. You should create the table before calling copyInto()." - ) - ) + " the column names to use. 
You should create the table before calling copyInto().")) // case 3: COPY transformation doesn't match target table createTable(testTableName, "c1 String") @@ -530,15 +479,12 @@ class CopyableDataFrameSuite extends SNTestBase { TestUtils.treeString(session.table(testTableName).schema, 0) == s"""root | |--A: Variant (nullable = true) - |""".stripMargin - ) + |""".stripMargin) checkAnswer( session.table(testTableName), Seq( Row("{\n \"num\": 1,\n \"str\": \"str1\"\n}"), - Row("{\n \"num\": 2,\n \"str\": \"str2\"\n}") - ) - ) + Row("{\n \"num\": 2,\n \"str\": \"str2\"\n}"))) // copy again, skip loaded files df.copyInto(testTableName, Seq(col("$1").as("A"))) @@ -546,9 +492,7 @@ class CopyableDataFrameSuite extends SNTestBase { session.table(testTableName), Seq( Row("{\n \"num\": 1,\n \"str\": \"str1\"\n}"), - Row("{\n \"num\": 2,\n \"str\": \"str2\"\n}") - ) - ) + Row("{\n \"num\": 2,\n \"str\": \"str2\"\n}"))) // copy again with FORCE = true. df.copyInto(testTableName, Seq(col("$1").as("B")), Map("FORCE" -> true)) @@ -558,9 +502,7 @@ class CopyableDataFrameSuite extends SNTestBase { Row("{\n \"num\": 1,\n \"str\": \"str1\"\n}"), Row("{\n \"num\": 2,\n \"str\": \"str2\"\n}"), Row("{\n \"num\": 1,\n \"str\": \"str1\"\n}"), - Row("{\n \"num\": 2,\n \"str\": \"str2\"\n}") - ) - ) + Row("{\n \"num\": 2,\n \"str\": \"str2\"\n}"))) } test("copy parquet test: write with transformation") { @@ -573,9 +515,7 @@ class CopyableDataFrameSuite extends SNTestBase { Seq( sqlExpr("$1:num").cast(IntegerType).as("num"), sqlExpr("$1:str").as("str"), - length(sqlExpr("$1:str")).as("str_length") - ) - ) + length(sqlExpr("$1:str")).as("str_length"))) assert( TestUtils.treeString(session.table(testTableName).schema, 0) == @@ -583,8 +523,7 @@ class CopyableDataFrameSuite extends SNTestBase { | |--NUM: Long (nullable = true) | |--STR: Variant (nullable = true) | |--STR_LENGTH: Long (nullable = true) - |""".stripMargin - ) + |""".stripMargin) checkAnswer(session.table(testTableName), Seq(Row(1, 
"\"str1\"", 4), Row(2, "\"str2\"", 4))) // copy again with existed table and FORCE = true. @@ -594,18 +533,15 @@ class CopyableDataFrameSuite extends SNTestBase { Seq( length(sqlExpr("$1:str")).as("str_length"), sqlExpr("$1:str").as("str"), - sqlExpr("$1:num").cast(IntegerType).as("num") - ), - Map("FORCE" -> true) - ) + sqlExpr("$1:num").cast(IntegerType).as("num")), + Map("FORCE" -> true)) assert( TestUtils.treeString(session.table(testTableName).schema, 0) == s"""root | |--NUM: Long (nullable = true) | |--STR: Variant (nullable = true) | |--STR_LENGTH: Long (nullable = true) - |""".stripMargin - ) + |""".stripMargin) checkAnswer( session.table(testTableName), @@ -613,9 +549,7 @@ class CopyableDataFrameSuite extends SNTestBase { Row(1, "\"str1\"", 4), Row(2, "\"str2\"", 4), Row(4, "\"str1\"", 1), - Row(4, "\"str2\"", 2) - ) - ) + Row(4, "\"str2\"", 2))) } test("copy parquet test: negative test") { @@ -630,9 +564,7 @@ class CopyableDataFrameSuite extends SNTestBase { assert( ex1.errorCode.equals("0122") && ex1.message.contains( s"Cannot create the target table $testTableName because Snowpark cannot determine" + - " the column names to use. You should create the table before calling copyInto()." - ) - ) + " the column names to use. You should create the table before calling copyInto().")) // case 2: copy with transformation when target table doesn't exist dropTable(testTableName)(session) @@ -642,9 +574,7 @@ class CopyableDataFrameSuite extends SNTestBase { assert( ex2.errorCode.equals("0122") && ex2.message.contains( s"Cannot create the target table $testTableName because Snowpark cannot determine" + - " the column names to use. You should create the table before calling copyInto()." - ) - ) + " the column names to use. 
You should create the table before calling copyInto().")) // case 3: COPY transformation doesn't match target table createTable(testTableName, "c1 String") @@ -665,15 +595,12 @@ class CopyableDataFrameSuite extends SNTestBase { TestUtils.treeString(session.table(testTableName).schema, 0) == s"""root | |--A: Variant (nullable = true) - |""".stripMargin - ) + |""".stripMargin) checkAnswer( session.table(testTableName), Seq( Row("{\n \"num\": 1,\n \"str\": \"str1\"\n}"), - Row("{\n \"num\": 2,\n \"str\": \"str2\"\n}") - ) - ) + Row("{\n \"num\": 2,\n \"str\": \"str2\"\n}"))) // copy again, skip loaded files df.copyInto(testTableName, Seq(col("$1").as("A"))) @@ -681,9 +608,7 @@ class CopyableDataFrameSuite extends SNTestBase { session.table(testTableName), Seq( Row("{\n \"num\": 1,\n \"str\": \"str1\"\n}"), - Row("{\n \"num\": 2,\n \"str\": \"str2\"\n}") - ) - ) + Row("{\n \"num\": 2,\n \"str\": \"str2\"\n}"))) // copy again with FORCE = true. df.copyInto(testTableName, Seq(col("$1").as("B")), Map("FORCE" -> true)) @@ -693,9 +618,7 @@ class CopyableDataFrameSuite extends SNTestBase { Row("{\n \"num\": 1,\n \"str\": \"str1\"\n}"), Row("{\n \"num\": 2,\n \"str\": \"str2\"\n}"), Row("{\n \"num\": 1,\n \"str\": \"str1\"\n}"), - Row("{\n \"num\": 2,\n \"str\": \"str2\"\n}") - ) - ) + Row("{\n \"num\": 2,\n \"str\": \"str2\"\n}"))) } test("copy avro test: write with transformation") { @@ -708,9 +631,7 @@ class CopyableDataFrameSuite extends SNTestBase { Seq( sqlExpr("$1:num").cast(IntegerType).as("num"), sqlExpr("$1:str").as("str"), - length(sqlExpr("$1:str")).as("str_length") - ) - ) + length(sqlExpr("$1:str")).as("str_length"))) assert( TestUtils.treeString(session.table(testTableName).schema, 0) == @@ -718,8 +639,7 @@ class CopyableDataFrameSuite extends SNTestBase { | |--NUM: Long (nullable = true) | |--STR: Variant (nullable = true) | |--STR_LENGTH: Long (nullable = true) - |""".stripMargin - ) + |""".stripMargin) checkAnswer(session.table(testTableName), Seq(Row(1, 
"\"str1\"", 4), Row(2, "\"str2\"", 4))) // copy again with existed table and FORCE = true. @@ -729,27 +649,22 @@ class CopyableDataFrameSuite extends SNTestBase { Seq( length(sqlExpr("$1:str")).as("str_length"), sqlExpr("$1:str").as("str"), - sqlExpr("$1:num").cast(IntegerType).as("num") - ), - Map("FORCE" -> true) - ) + sqlExpr("$1:num").cast(IntegerType).as("num")), + Map("FORCE" -> true)) assert( TestUtils.treeString(session.table(testTableName).schema, 0) == s"""root | |--NUM: Long (nullable = true) | |--STR: Variant (nullable = true) | |--STR_LENGTH: Long (nullable = true) - |""".stripMargin - ) + |""".stripMargin) checkAnswer( session.table(testTableName), Seq( Row(1, "\"str1\"", 4), Row(2, "\"str2\"", 4), Row(4, "\"str1\"", 1), - Row(4, "\"str2\"", 2) - ) - ) + Row(4, "\"str2\"", 2))) } test("copy avro test: negative test") { @@ -764,9 +679,7 @@ class CopyableDataFrameSuite extends SNTestBase { assert( ex1.errorCode.equals("0122") && ex1.message.contains( s"Cannot create the target table $testTableName because Snowpark cannot determine" + - " the column names to use. You should create the table before calling copyInto()." - ) - ) + " the column names to use. You should create the table before calling copyInto().")) // case 2: copy with transformation when target table doesn't exist dropTable(testTableName)(session) @@ -776,9 +689,7 @@ class CopyableDataFrameSuite extends SNTestBase { assert( ex2.errorCode.equals("0122") && ex2.message.contains( s"Cannot create the target table $testTableName because Snowpark cannot determine" + - " the column names to use. You should create the table before calling copyInto()." - ) - ) + " the column names to use. 
You should create the table before calling copyInto().")) // case 3: COPY transformation doesn't match target table createTable(testTableName, "c1 String") @@ -799,15 +710,12 @@ class CopyableDataFrameSuite extends SNTestBase { TestUtils.treeString(session.table(testTableName).schema, 0) == s"""root | |--A: Variant (nullable = true) - |""".stripMargin - ) + |""".stripMargin) checkAnswer( session.table(testTableName), Seq( Row("{\n \"num\": 1,\n \"str\": \"str1\"\n}"), - Row("{\n \"num\": 2,\n \"str\": \"str2\"\n}") - ) - ) + Row("{\n \"num\": 2,\n \"str\": \"str2\"\n}"))) // copy again, skip loaded files df.copyInto(testTableName, Seq(col("$1").as("A"))) @@ -815,9 +723,7 @@ class CopyableDataFrameSuite extends SNTestBase { session.table(testTableName), Seq( Row("{\n \"num\": 1,\n \"str\": \"str1\"\n}"), - Row("{\n \"num\": 2,\n \"str\": \"str2\"\n}") - ) - ) + Row("{\n \"num\": 2,\n \"str\": \"str2\"\n}"))) // copy again with FORCE = true. df.copyInto(testTableName, Seq(col("$1").as("B")), Map("FORCE" -> true)) @@ -827,9 +733,7 @@ class CopyableDataFrameSuite extends SNTestBase { Row("{\n \"num\": 1,\n \"str\": \"str1\"\n}"), Row("{\n \"num\": 2,\n \"str\": \"str2\"\n}"), Row("{\n \"num\": 1,\n \"str\": \"str1\"\n}"), - Row("{\n \"num\": 2,\n \"str\": \"str2\"\n}") - ) - ) + Row("{\n \"num\": 2,\n \"str\": \"str2\"\n}"))) } test("copy orc test: write with transformation") { @@ -842,9 +746,7 @@ class CopyableDataFrameSuite extends SNTestBase { Seq( sqlExpr("$1:num").cast(IntegerType).as("num"), sqlExpr("$1:str").as("str"), - length(sqlExpr("$1:str")).as("str_length") - ) - ) + length(sqlExpr("$1:str")).as("str_length"))) assert( TestUtils.treeString(session.table(testTableName).schema, 0) == @@ -852,8 +754,7 @@ class CopyableDataFrameSuite extends SNTestBase { | |--NUM: Long (nullable = true) | |--STR: Variant (nullable = true) | |--STR_LENGTH: Long (nullable = true) - |""".stripMargin - ) + |""".stripMargin) checkAnswer(session.table(testTableName), Seq(Row(1, 
"\"str1\"", 4), Row(2, "\"str2\"", 4))) // copy again with existed table and FORCE = true. @@ -863,27 +764,22 @@ class CopyableDataFrameSuite extends SNTestBase { Seq( length(sqlExpr("$1:str")).as("str_length"), sqlExpr("$1:str").as("str"), - sqlExpr("$1:num").cast(IntegerType).as("num") - ), - Map("FORCE" -> true) - ) + sqlExpr("$1:num").cast(IntegerType).as("num")), + Map("FORCE" -> true)) assert( TestUtils.treeString(session.table(testTableName).schema, 0) == s"""root | |--NUM: Long (nullable = true) | |--STR: Variant (nullable = true) | |--STR_LENGTH: Long (nullable = true) - |""".stripMargin - ) + |""".stripMargin) checkAnswer( session.table(testTableName), Seq( Row(1, "\"str1\"", 4), Row(2, "\"str2\"", 4), Row(4, "\"str1\"", 1), - Row(4, "\"str2\"", 2) - ) - ) + Row(4, "\"str2\"", 2))) } test("copy orc test: negative test") { @@ -898,9 +794,7 @@ class CopyableDataFrameSuite extends SNTestBase { assert( ex1.errorCode.equals("0122") && ex1.message.contains( s"Cannot create the target table $testTableName because Snowpark cannot determine" + - " the column names to use. You should create the table before calling copyInto()." - ) - ) + " the column names to use. You should create the table before calling copyInto().")) // case 2: copy with transformation when target table doesn't exist dropTable(testTableName)(session) @@ -910,9 +804,7 @@ class CopyableDataFrameSuite extends SNTestBase { assert( ex2.errorCode.equals("0122") && ex2.message.contains( s"Cannot create the target table $testTableName because Snowpark cannot determine" + - " the column names to use. You should create the table before calling copyInto()." - ) - ) + " the column names to use. 
You should create the table before calling copyInto().")) // case 3: COPY transformation doesn't match target table createTable(testTableName, "c1 String") @@ -933,15 +825,12 @@ class CopyableDataFrameSuite extends SNTestBase { TestUtils.treeString(session.table(testTableName).schema, 0) == s"""root | |--A: Variant (nullable = true) - |""".stripMargin - ) + |""".stripMargin) checkAnswer( session.table(testTableName), Seq( Row("\n 1\n str1\n"), - Row("\n 2\n str2\n") - ) - ) + Row("\n 2\n str2\n"))) // copy again, skip loaded files df.copyInto(testTableName, Seq(col("$1").as("A"))) @@ -949,9 +838,7 @@ class CopyableDataFrameSuite extends SNTestBase { session.table(testTableName), Seq( Row("\n 1\n str1\n"), - Row("\n 2\n str2\n") - ) - ) + Row("\n 2\n str2\n"))) // copy again with FORCE = true. df.copyInto(testTableName, Seq(col("$1").as("B")), Map("FORCE" -> true)) @@ -961,9 +848,7 @@ class CopyableDataFrameSuite extends SNTestBase { Row("\n 1\n str1\n"), Row("\n 2\n str2\n"), Row("\n 1\n str1\n"), - Row("\n 2\n str2\n") - ) - ) + Row("\n 2\n str2\n"))) } test("copy xml test: write with transformation") { @@ -976,17 +861,14 @@ class CopyableDataFrameSuite extends SNTestBase { Seq( get(xmlget(col("$1"), lit("num"), lit(0)), lit("$")).cast(IntegerType).as("num"), get(xmlget(col("$1"), lit("str"), lit(0)), lit("$")).as("str"), - length(get(xmlget(col("$1"), lit("str"), lit(0)), lit("$"))).as("str_length") - ) - ) + length(get(xmlget(col("$1"), lit("str"), lit(0)), lit("$"))).as("str_length"))) assert( TestUtils.treeString(session.table(testTableName).schema, 0) == s"""root | |--NUM: Long (nullable = true) | |--STR: Variant (nullable = true) | |--STR_LENGTH: Long (nullable = true) - |""".stripMargin - ) + |""".stripMargin) checkAnswer(session.table(testTableName), Seq(Row(1, "\"str1\"", 4), Row(2, "\"str2\"", 4))) // copy again, skip loaded files @@ -995,9 +877,7 @@ class CopyableDataFrameSuite extends SNTestBase { Seq( get(xmlget(col("$1"), lit("num"), lit(0)), 
lit("$")).cast(IntegerType).as("num"), get(xmlget(col("$1"), lit("str"), lit(0)), lit("$")).as("str"), - length(get(xmlget(col("$1"), lit("str"), lit(0)), lit("$"))).as("str_length") - ) - ) + length(get(xmlget(col("$1"), lit("str"), lit(0)), lit("$"))).as("str_length"))) checkAnswer(session.table(testTableName), Seq(Row(1, "\"str1\"", 4), Row(2, "\"str2\"", 4))) // copy again with existed table and FORCE = true. @@ -1007,19 +887,15 @@ class CopyableDataFrameSuite extends SNTestBase { Seq( length(get(xmlget(col("$1"), lit("str"), lit(0)), lit("$"))).as("str_length"), get(xmlget(col("$1"), lit("str"), lit(0)), lit("$")).as("str"), - get(xmlget(col("$1"), lit("num"), lit(0)), lit("$")).cast(IntegerType).as("num") - ), - Map("FORCE" -> true) - ) + get(xmlget(col("$1"), lit("num"), lit(0)), lit("$")).cast(IntegerType).as("num")), + Map("FORCE" -> true)) checkAnswer( session.table(testTableName), Seq( Row(1, "\"str1\"", 4), Row(2, "\"str2\"", 4), Row(4, "\"str1\"", 1), - Row(4, "\"str2\"", 2) - ) - ) + Row(4, "\"str2\"", 2))) } test("copy xml test: negative test") { @@ -1034,9 +910,7 @@ class CopyableDataFrameSuite extends SNTestBase { assert( ex1.errorCode.equals("0122") && ex1.message.contains( s"Cannot create the target table $testTableName because Snowpark cannot determine" + - " the column names to use. You should create the table before calling copyInto()." - ) - ) + " the column names to use. You should create the table before calling copyInto().")) // case 2: copy with transformation when target table doesn't exist dropTable(testTableName)(session) @@ -1046,9 +920,7 @@ class CopyableDataFrameSuite extends SNTestBase { assert( ex2.errorCode.equals("0122") && ex2.message.contains( s"Cannot create the target table $testTableName because Snowpark cannot determine" + - " the column names to use. You should create the table before calling copyInto()." - ) - ) + " the column names to use. 
You should create the table before calling copyInto().")) // case 3: COPY transformation doesn't match target table createTable(testTableName, "c1 String") @@ -1120,22 +992,19 @@ class CopyableDataFrameSuite extends SNTestBase { val asyncJob2 = df.async.copyInto( testTableName, Seq(col("$1"), col("$1"), col("$1")), - Map("skip_header" -> 1, "FORCE" -> "true") - ) + Map("skip_header" -> 1, "FORCE" -> "true")) val res2 = asyncJob2.getResult() assert(res2.isInstanceOf[Unit]) // Check result in target table checkAnswer( session.table(testTableName), - Seq(Row("1.2", "one", "1"), Row("2.2", "two", "2"), Row("2", "2", "2")) - ) + Seq(Row("1.2", "one", "1"), Row("2.2", "two", "2"), Row("2", "2", "2"))) // copy data with transformation, options and target columns val asyncJob3 = df.async.copyInto( testTableName, Seq("c3", "c2", "c1"), Seq(length(col("$1")), col("$2"), length(col("$3"))), - Map("FORCE" -> "true") - ) + Map("FORCE" -> "true")) asyncJob3.getResult() val res3 = asyncJob3.getResult() assert(res3.isInstanceOf[Unit]) @@ -1147,9 +1016,7 @@ class CopyableDataFrameSuite extends SNTestBase { Row("2.2", "two", "2"), Row("2", "2", "2"), Row("3", "one", "1"), - Row("3", "two", "1") - ) - ) + Row("3", "two", "1"))) } } diff --git a/src/test/scala/com/snowflake/snowpark_test/DataFrameAggregateSuite.scala b/src/test/scala/com/snowflake/snowpark_test/DataFrameAggregateSuite.scala index 605a37bf..7418701a 100644 --- a/src/test/scala/com/snowflake/snowpark_test/DataFrameAggregateSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/DataFrameAggregateSuite.scala @@ -17,8 +17,7 @@ class DataFrameAggregateSuite extends TestData { .agg(sum(col("amount"))) .sort(col("empid")), Seq(Row(1, 10400, 8000, 11000, 18000), Row(2, 39500, 90700, 12000, 5300)), - sort = false - ) + sort = false) // multiple aggregation isn't supported val e = intercept[SnowparkClientException]( @@ -26,14 +25,10 @@ class DataFrameAggregateSuite extends TestData { .pivot("month", Seq("JAN", "FEB", 
"MAR", "APR")) .agg(sum(col("amount")), avg(col("amount"))) .sort(col("empid")) - .collect() - ) + .collect()) assert( - e.getMessage.contains( - "You can apply only one aggregate expression to a" + - " RelationalGroupedDataFrame returned by the pivot() method." - ) - ) + e.getMessage.contains("You can apply only one aggregate expression to a" + + " RelationalGroupedDataFrame returned by the pivot() method.")) } test("join on pivot") { @@ -47,8 +42,7 @@ class DataFrameAggregateSuite extends TestData { checkAnswer( df1.join(df2, "empid"), Seq(Row(1, 10400, 8000, 11000, 18000, 12345), Row(2, 39500, 90700, 12000, 5300, 67890)), - sort = false - ) + sort = false) } test("pivot on join") { @@ -60,8 +54,7 @@ class DataFrameAggregateSuite extends TestData { .agg(sum(col("amount"))) .sort(col("name")), Seq(Row(1, "One", 10400, 8000, 11000, 18000), Row(2, "Two", 39500, 90700, 12000, 5300)), - sort = false - ) + sort = false) } test("RelationalGroupedDataFrame.agg()") { @@ -94,8 +87,7 @@ class DataFrameAggregateSuite extends TestData { .groupBy("radio_license") .agg(count(col("*")).as("count")) .withColumn("medical_license", lit(null)) - .select("medical_license", "radio_license", "count") - ) + .select("medical_license", "radio_license", "count")) .sort(col("count")) .collect() @@ -115,10 +107,8 @@ class DataFrameAggregateSuite extends TestData { Row("RN", null, 2), Row(null, "Technician", 2), Row(null, null, 3), - Row("LVN", null, 5) - ), - sort = false - ) + Row("LVN", null, 5)), + sort = false) // comparing with groupBy checkAnswer( @@ -132,18 +122,15 @@ class DataFrameAggregateSuite extends TestData { Row(1, "RN", "Amateur Extra"), Row(1, "RN", null), Row(2, "LVN", "Technician"), - Row(2, "LVN", null) - ), - sort = false - ) + Row(2, "LVN", null)), + sort = false) // mixed grouping expression checkAnswer( nurse .groupByGroupingSets( GroupingSets(Set(col("medical_license"), col("radio_license"))), - GroupingSets(Set(col("radio_license"))) - ) + 
GroupingSets(Set(col("radio_license")))) // duplicated column should be removed in the result .agg(col("radio_license")) .sort(col("radio_license")), @@ -152,10 +139,8 @@ class DataFrameAggregateSuite extends TestData { Row("RN", null), Row("RN", "Amateur Extra"), Row("LVN", "General"), - Row("LVN", "Technician") - ), - sort = false - ) + Row("LVN", "Technician")), + sort = false) // default constructor checkAnswer( @@ -163,9 +148,7 @@ class DataFrameAggregateSuite extends TestData { .groupByGroupingSets( Seq( GroupingSets(Seq(Set(col("medical_license"), col("radio_license")))), - GroupingSets(Seq(Set(col("radio_license")))) - ) - ) + GroupingSets(Seq(Set(col("radio_license")))))) // duplicated column should be removed in the result .agg(col("radio_license")) .sort(col("radio_license")), @@ -174,10 +157,8 @@ class DataFrameAggregateSuite extends TestData { Row("RN", null), Row("RN", "Amateur Extra"), Row("LVN", "General"), - Row("LVN", "Technician") - ), - sort = false - ) + Row("LVN", "Technician")), + sort = false) } test("RelationalGroupedDataFrame.max()") { @@ -235,8 +216,7 @@ class DataFrameAggregateSuite extends TestData { ("b", 2, 22, "c"), ("a", 3, 33, "d"), ("b", 4, 44, "e"), - ("b", 44, 444, "f") - ) + ("b", 44, 444, "f")) .toDF("key", "value1", "value2", "rest") // call median without group-by @@ -254,8 +234,7 @@ class DataFrameAggregateSuite extends TestData { // no arguments checkAnswer( df.groupBy("a").builtin("max")(col("a"), col("b")), - Seq(Row(1, 1, 13), Row(2, 2, 12)) - ) + Seq(Row(1, 1, 13), Row(2, 2, 12))) // with arguments checkAnswer(df.groupBy("a").builtin("max")(col("b")), Seq(Row(1, 13), Row(2, 12))) } @@ -285,15 +264,13 @@ class DataFrameAggregateSuite extends TestData { testData2.rollup($"a" + $"b" as "foo", $"b" as "bar").agg(sum($"a" - $"b") as "foo"), Row(2, 1, 0) :: Row(2, null, 0) :: Row(3, 1, 1) :: Row(3, 2, -1) :: Row(3, null, 0) :: Row(4, 1, 2) :: Row(4, 2, 0) :: Row(4, null, 2) :: Row(5, 2, 1) - :: Row(5, null, 1) :: Row(null, 
null, 3) :: Nil - ) + :: Row(5, null, 1) :: Row(null, null, 3) :: Nil) checkAnswer( testData2.rollup("a", "b").agg(sum(col("b"))), Row(1, 1, 1) :: Row(1, 2, 2) :: Row(1, null, 3) :: Row(2, 1, 1) :: Row(2, 2, 2) :: Row(2, null, 3) :: Row(3, 1, 1) :: Row(3, 2, 2) :: Row(3, null, 3) - :: Row(null, null, 9) :: Nil - ) + :: Row(null, null, 9) :: Nil) } test("cube overlapping columns") { @@ -302,15 +279,13 @@ class DataFrameAggregateSuite extends TestData { Row(2, 1, 0) :: Row(2, null, 0) :: Row(3, 1, 1) :: Row(3, 2, -1) :: Row(3, null, 0) :: Row(4, 1, 2) :: Row(4, 2, 0) :: Row(4, null, 2) :: Row(5, 2, 1) :: Row(5, null, 1) :: Row(null, 1, 3) :: Row(null, 2, 0) - :: Row(null, null, 3) :: Nil - ) + :: Row(null, null, 3) :: Nil) checkAnswer( testData2.cube("a", "b").agg(sum(col("b"))), Row(1, 1, 1) :: Row(1, 2, 2) :: Row(1, null, 3) :: Row(2, 1, 1) :: Row(2, 2, 2) :: Row(2, null, 3) :: Row(3, 1, 1) :: Row(3, 2, 2) :: Row(3, null, 3) - :: Row(null, 1, 3) :: Row(null, 2, 6) :: Row(null, null, 9) :: Nil - ) + :: Row(null, 1, 3) :: Row(null, 2, 6) :: Row(null, null, 9) :: Nil) } test("null count") { @@ -321,13 +296,11 @@ class DataFrameAggregateSuite extends TestData { checkAnswer( testData3 .agg(count($"a"), count($"b"), count(lit(1)), count_distinct($"a"), count_distinct($"b")), - Seq(Row(2, 1, 2, 2, 1)) - ) + Seq(Row(2, 1, 2, 2, 1))) checkAnswer( testData3.agg(count($"b"), count_distinct($"b"), sum_distinct($"b")), // non-partial - Seq(Row(1, 1, 2)) - ) + Seq(Row(1, 1, 2))) } // Used temporary VIEW which is not supported by owner's mode stored proc yet @@ -343,65 +316,52 @@ class DataFrameAggregateSuite extends TestData { checkWindowError( testData2 .groupBy($"a") - .agg(max(sum(sum($"b")).over(Window.orderBy($"a")))) - ) + .agg(max(sum(sum($"b")).over(Window.orderBy($"a"))))) checkWindowError( testData2 .groupBy($"a") .agg( sum($"b").as("s"), - max( - count(col("*")) - .over() - ) - ) - .where($"s" === 3) - ) + max(count(col("*")) + .over())) + .where($"s" === 3)) 
checkAnswer( testData2 .groupBy($"a") .agg(max($"b"), sum($"b").as("s"), count(col("*")).over()) .where($"s" === 3), - Row(1, 2, 3, 3) :: Row(2, 2, 3, 3) :: Row(3, 2, 3, 3) :: Nil - ) + Row(1, 2, 3, 3) :: Row(2, 2, 3, 3) :: Row(3, 2, 3, 3) :: Nil) testData2.createOrReplaceTempView("testData2") checkWindowError(session.sql("SELECT MIN(AVG(b) OVER(PARTITION BY a)) FROM testData2")) checkWindowError(session.sql("SELECT SUM(b), MAX(RANK() OVER(ORDER BY a)) FROM testData2")) checkWindowError( - session.sql("SELECT SUM(b), MAX(RANK() OVER(ORDER BY b)) FROM testData2 GROUP BY a") - ) + session.sql("SELECT SUM(b), MAX(RANK() OVER(ORDER BY b)) FROM testData2 GROUP BY a")) checkWindowError( - session.sql("SELECT MAX(SUM(SUM(b)) OVER(ORDER BY a)) FROM testData2 GROUP BY a") - ) + session.sql("SELECT MAX(SUM(SUM(b)) OVER(ORDER BY a)) FROM testData2 GROUP BY a")) checkWindowError( - session.sql("SELECT MAX(RANK() OVER(ORDER BY b)) FROM testData2 GROUP BY a HAVING SUM(b) = 3") - ) + session.sql( + "SELECT MAX(RANK() OVER(ORDER BY b)) FROM testData2 GROUP BY a HAVING SUM(b) = 3")) checkAnswer( session.sql( - "SELECT a, MAX(b), RANK() OVER(ORDER BY a) FROM testData2 GROUP BY a HAVING SUM(b) = 3" - ), - Row(1, 2, 1) :: Row(2, 2, 2) :: Row(3, 2, 3) :: Nil - ) + "SELECT a, MAX(b), RANK() OVER(ORDER BY a) FROM testData2 GROUP BY a HAVING SUM(b) = 3"), + Row(1, 2, 1) :: Row(2, 2, 2) :: Row(3, 2, 3) :: Nil) } test("distinct") { val df = Seq((1, "one", 1.0), (2, "one", 2.0), (2, "two", 1.0)).toDF("i", "s", """"i"""") checkAnswer( df.distinct(), - Row(1, "one", 1.0) :: Row(2, "one", 2.0) :: Row(2, "two", 1.0) :: Nil - ) + Row(1, "one", 1.0) :: Row(2, "one", 2.0) :: Row(2, "two", 1.0) :: Nil) checkAnswer(df.select("i").distinct(), Row(1) :: Row(2) :: Nil) checkAnswer(df.select(""""i"""").distinct(), Row(1.0) :: Row(2.0) :: Nil) checkAnswer(df.select("s").distinct(), Row("one") :: Row("two") :: Nil) checkAnswer( df.select("i", """"i"""").distinct(), - Row(1, 1.0) :: Row(2, 1.0) :: Row(2, 
2.0) :: Nil - ) + Row(1, 1.0) :: Row(2, 1.0) :: Row(2, 2.0) :: Nil) checkAnswer( df.select("i", """"i"""").distinct(), - Row(1, 1.0) :: Row(2, 1.0) :: Row(2, 2.0) :: Nil - ) + Row(1, 1.0) :: Row(2, 1.0) :: Row(2, 2.0) :: Nil) checkAnswer(df.filter($"i" < 0).distinct(), Nil) } @@ -411,19 +371,16 @@ class DataFrameAggregateSuite extends TestData { checkAnswer( lhs.join(rhs, lhs("i") === rhs("i")).distinct(), - Row(1, "one", 1.0, 1, "one", 1.0) :: Row(2, "one", 2.0, 2, "one", 2.0) :: Nil - ) + Row(1, "one", 1.0, 1, "one", 1.0) :: Row(2, "one", 2.0, 2, "one", 2.0) :: Nil) val lhsD = lhs.select($"s").distinct() checkAnswer( lhsD.join(rhs, lhsD("s") === rhs("s")), - Row("one", 1, "one", 1.0) :: Row("one", 2, "one", 2.0) :: Nil - ) + Row("one", 1, "one", 1.0) :: Row("one", 2, "one", 2.0) :: Nil) var rhsD = rhs.select($"s") checkAnswer( lhsD.join(rhsD, lhsD("s") === rhsD("s")), - Row("one", "one") :: Row("one", "one") :: Nil - ) + Row("one", "one") :: Row("one", "one") :: Nil) rhsD = rhs.select($"s").distinct() checkAnswer(lhsD.join(rhsD, lhsD("s") === rhsD("s")), Row("one", "one") :: Nil) @@ -434,12 +391,10 @@ class DataFrameAggregateSuite extends TestData { checkAnswer(testData2.groupBy("a").agg(count($"*")), Row(1, 2) :: Row(2, 2) :: Row(3, 2) :: Nil) checkAnswer( testData2.groupBy("a").agg(Map($"*" -> "count")), - Row(1, 2) :: Row(2, 2) :: Row(3, 2) :: Nil - ) + Row(1, 2) :: Row(2, 2) :: Row(3, 2) :: Nil) checkAnswer( testData2.groupBy("a").agg(Map($"b" -> "sum")), - Row(1, 3) :: Row(2, 3) :: Row(3, 3) :: Nil - ) + Row(1, 3) :: Row(2, 3) :: Row(3, 3) :: Nil) val df1 = Seq(("a", 1, 0, "b"), ("b", 2, 4, "c"), ("a", 2, 3, "d")) .toDF("key", "value1", "value2", "rest") @@ -451,9 +406,7 @@ class DataFrameAggregateSuite extends TestData { Seq( Row(new java.math.BigDecimal(1), new java.math.BigDecimal(3)), Row(new java.math.BigDecimal(2), new java.math.BigDecimal(3)), - Row(new java.math.BigDecimal(3), new java.math.BigDecimal(3)) - ) - ) + Row(new java.math.BigDecimal(3), 
new java.math.BigDecimal(3)))) } test("SN - agg should be ordering preserving") { @@ -474,17 +427,14 @@ class DataFrameAggregateSuite extends TestData { Row("dotNET", null, 63000.0) :: Row(null, 2012, 35000.0) :: Row(null, 2013, 78000.0) :: - Row(null, null, 113000.0) :: Nil - ) + Row(null, null, 113000.0) :: Nil) val df0 = session.createDataFrame( Seq( Fact(20151123, 18, 35, "room1", 18.6), Fact(20151123, 18, 35, "room2", 22.4), Fact(20151123, 18, 36, "room1", 17.4), - Fact(20151123, 18, 36, "room2", 25.6) - ) - ) + Fact(20151123, 18, 36, "room2", 25.6))) val cube0 = df0.cube("date", "hour", "minute", "room_name").agg(Map($"temp" -> "avg")) assert(cube0.where(col("date").is_null).count > 0) @@ -503,8 +453,7 @@ class DataFrameAggregateSuite extends TestData { Row("dotNET", null, 0, 1, 1) :: Row(null, 2012, 1, 0, 2) :: Row(null, 2013, 1, 0, 2) :: - Row(null, null, 1, 1, 3) :: Nil - ) + Row(null, null, 1, 1, 3) :: Nil) // use column reference in `grouping_id` instead of column name checkAnswer( @@ -519,8 +468,7 @@ class DataFrameAggregateSuite extends TestData { Row("dotNET", null, 1) :: Row(null, 2012, 2) :: Row(null, 2013, 2) :: - Row(null, null, 3) :: Nil - ) + Row(null, null, 3) :: Nil) /* TODO: Add another test with eager analysis */ @@ -538,16 +486,14 @@ class DataFrameAggregateSuite extends TestData { test("SN - count") { checkAnswer( testData2.agg(count($"a"), sum_distinct($"a")), // non-partial - Seq(Row(6, 6.0)) - ) + Seq(Row(6, 6.0))) } test("SN - stddev") { val testData2ADev = math.sqrt(4.0 / 5.0) checkAnswer( testData2.agg(stddev($"a"), stddev_pop($"a"), stddev_samp($"a")), - Seq(Row(testData2ADev, 0.8164967850518458, testData2ADev)) - ) + Seq(Row(testData2ADev, 0.8164967850518458, testData2ADev))) } test("SN - moments") { @@ -559,13 +505,11 @@ class DataFrameAggregateSuite extends TestData { checkAnswer( testData2.groupBy($"a").agg(variance($"b")), - Row(1, 0.50000) :: Row(2, 0.50000) :: Row(3, 0.500000) :: Nil - ) + Row(1, 0.50000) :: Row(2, 0.50000) 
:: Row(3, 0.500000) :: Nil) var statement = runQueryReturnStatement( "select variance(a) from values(1,1),(1,2),(2,1),(2,2),(3,1),(3,2) as T(a,b);", - session - ) + session) val varianceResult = statement.getResultSet while (varianceResult.next()) { @@ -590,8 +534,7 @@ class DataFrameAggregateSuite extends TestData { // add sql test statement = runQueryReturnStatement( "select kurtosis(a) from values(1,1),(1,2),(2,1),(2,2),(3,1),(3,2) as T(a,b);", - session - ) + session) val aggKurtosisResult = statement.getResultSet while (aggKurtosisResult.next()) { @@ -611,10 +554,8 @@ class DataFrameAggregateSuite extends TestData { var_samp($"a"), var_pop($"a"), skew($"a"), - kurtosis($"a") - ), - Seq(Row(null, null, 0.0, null, 0.000000, null, null)) - ) + kurtosis($"a")), + Seq(Row(null, null, 0.0, null, 0.000000, null, null))) checkAnswer( input.agg( @@ -625,10 +566,8 @@ class DataFrameAggregateSuite extends TestData { sqlExpr("var_samp(a)"), sqlExpr("var_pop(a)"), sqlExpr("skew(a)"), - sqlExpr("kurtosis(a)") - ), - Row(null, null, 0.0, null, null, 0.0, null, null) - ) + sqlExpr("kurtosis(a)")), + Row(null, null, 0.0, null, null, 0.0, null, null)) } test("SN - null moments") { @@ -636,8 +575,7 @@ class DataFrameAggregateSuite extends TestData { checkAnswer( emptyTableData .agg(variance($"a"), var_samp($"a"), var_pop($"a"), skew($"a"), kurtosis($"a")), - Seq(Row(null, null, null, null)) - ) + Seq(Row(null, null, null, null))) checkAnswer( emptyTableData.agg( @@ -645,21 +583,17 @@ class DataFrameAggregateSuite extends TestData { sqlExpr("var_samp(a)"), sqlExpr("var_pop(a)"), sqlExpr("skew(a)"), - sqlExpr("kurtosis(a)") - ), - Seq(Row(null, null, null, null, null)) - ) + sqlExpr("kurtosis(a)")), + Seq(Row(null, null, null, null, null))) } test("SN - Decimal sum/avg over window should work.") { checkAnswer( session.sql("select sum(a) over () from values (1.0), (2.0), (3.0) T(a)"), - Row(6.0) :: Row(6.0) :: Row(6.0) :: Nil - ) + Row(6.0) :: Row(6.0) :: Row(6.0) :: Nil) 
checkAnswer( session.sql("select avg(a) over () from values (1.0), (2.0), (3.0) T(a)"), - Row(2.0) :: Row(2.0) :: Row(2.0) :: Nil - ) + Row(2.0) :: Row(2.0) :: Row(2.0) :: Nil) } test("SN - aggregate function in GROUP BY") { @@ -672,24 +606,20 @@ class DataFrameAggregateSuite extends TestData { test("SN - ints in aggregation expressions are taken as group-by ordinal.") { checkAnswer( testData2.groupBy(lit(3), lit(4)).agg(lit(6), lit(7), sum($"b")), - Seq(Row(3, 4, 6, 7, 9)) - ) + Seq(Row(3, 4, 6, 7, 9))) checkAnswer( testData2.groupBy(lit(3), lit(4)).agg(lit(6), $"b", sum($"b")), - Seq(Row(3, 4, 6, 1, 3), Row(3, 4, 6, 2, 6)) - ) + Seq(Row(3, 4, 6, 1, 3), Row(3, 4, 6, 2, 6))) val testdata2Str: String = "(SELECT * FROM VALUES (1,1),(1,2),(2,1),(2,2),(3,1),(3,2) T(a, b) )" checkAnswer( session.sql(s"SELECT 3, 4, SUM(b) FROM $testdata2Str GROUP BY 1, 2"), - Seq(Row(3, 4, 9)) - ) + Seq(Row(3, 4, 9))) checkAnswer( session.sql(s"SELECT 3 AS c, 4 AS d, SUM(b) FROM $testdata2Str GROUP BY c, d"), - Seq(Row(3, 4, 9)) - ) + Seq(Row(3, 4, 9))) } test("distinct and unions") { @@ -734,41 +664,33 @@ class DataFrameAggregateSuite extends TestData { ("b", None), ("b", Some(4)), ("b", Some(5)), - ("b", Some(6)) - ) + ("b", Some(6))) .toDF("x", "y") .createOrReplaceTempView("tempView") checkAnswer( session.sql( "SELECT COUNT_IF(NULL), COUNT_IF(y % 2 = 0), COUNT_IF(y % 2 <> 0), " + - "COUNT_IF(y IS NULL) FROM tempView" - ), - Seq(Row(0L, 3L, 3L, 2L)) - ) + "COUNT_IF(y IS NULL) FROM tempView"), + Seq(Row(0L, 3L, 3L, 2L))) checkAnswer( session.sql( "SELECT x, COUNT_IF(NULL), COUNT_IF(y % 2 = 0), COUNT_IF(y % 2 <> 0), " + - "COUNT_IF(y IS NULL) FROM tempView GROUP BY x" - ), - Row("a", 0L, 1L, 2L, 1L) :: Row("b", 0L, 2L, 1L, 1L) :: Nil - ) + "COUNT_IF(y IS NULL) FROM tempView GROUP BY x"), + Row("a", 0L, 1L, 2L, 1L) :: Row("b", 0L, 2L, 1L, 1L) :: Nil) checkAnswer( session.sql("SELECT x FROM tempView GROUP BY x HAVING COUNT_IF(y % 2 = 0) = 1"), - Seq(Row("a")) - ) + Seq(Row("a"))) 
checkAnswer( session.sql("SELECT x FROM tempView GROUP BY x HAVING COUNT_IF(y % 2 = 0) = 2"), - Seq(Row("b")) - ) + Seq(Row("b"))) checkAnswer( session.sql("SELECT x FROM tempView GROUP BY x HAVING COUNT_IF(y IS NULL) > 0"), - Row("a") :: Row("b") :: Nil - ) + Row("a") :: Row("b") :: Nil) checkAnswer(session.sql("SELECT x FROM tempView GROUP BY x HAVING COUNT_IF(NULL) > 0"), Nil) @@ -786,8 +708,7 @@ class DataFrameAggregateSuite extends TestData { Row("dotNET", 2012, 15000.0) :: Row("dotNET", 2013, 48000.0) :: Row("dotNET", null, 63000.0) :: - Row(null, null, 113000.0) :: Nil - ) + Row(null, null, 113000.0) :: Nil) } test("grouping/grouping_id inside window function") { @@ -798,8 +719,8 @@ class DataFrameAggregateSuite extends TestData { .agg( sum($"earnings"), grouping_id($"course", $"year"), - rank().over(Window.partitionBy(grouping_id($"course", $"year")).orderBy(sum($"earnings"))) - ), + rank().over( + Window.partitionBy(grouping_id($"course", $"year")).orderBy(sum($"earnings")))), Row("Java", 2012, 20000.0, 0, 2) :: Row("Java", 2013, 30000.0, 0, 3) :: Row("Java", null, 50000.0, 1, 1) :: @@ -808,8 +729,7 @@ class DataFrameAggregateSuite extends TestData { Row("dotNET", null, 63000.0, 1, 2) :: Row(null, 2012, 35000.0, 2, 1) :: Row(null, 2013, 78000.0, 2, 2) :: - Row(null, null, 113000.0, 3, 1) :: Nil - ) + Row(null, null, 113000.0, 3, 1) :: Nil) } test("References in grouping functions should be indexed with semanticEquals") { @@ -825,8 +745,7 @@ class DataFrameAggregateSuite extends TestData { Row("dotNET", null, 0, 1) :: Row(null, 2012, 1, 0) :: Row(null, 2013, 1, 0) :: - Row(null, null, 1, 1) :: Nil - ) + Row(null, null, 1, 1) :: Nil) } test("agg without groups") { checkAnswer(testData2.agg(sum($"b")), Row(9)) @@ -843,8 +762,7 @@ class DataFrameAggregateSuite extends TestData { checkAnswer( testData3.agg(avg($"b"), sum_distinct($"b")), // non-partial - Row(2.0, 2.0) - ) + Row(2.0, 2.0)) } test("zero average") { @@ -853,8 +771,7 @@ class 
DataFrameAggregateSuite extends TestData { checkAnswer( emptyTableData.agg(avg($"a"), sum_distinct($"b")), // non-partial - Row(null, null) - ) + Row(null, null)) } test("multiple column distinct count") { val df1 = Seq( @@ -862,8 +779,7 @@ class DataFrameAggregateSuite extends TestData { ("a", "b", "c"), ("a", "b", "d"), ("x", "y", "z"), - ("x", "q", null.asInstanceOf[String]) - ) + ("x", "q", null.asInstanceOf[String])) .toDF("key1", "key2", "key3") checkAnswer(df1.agg(count_distinct($"key1", $"key2")), Row(3)) @@ -872,24 +788,21 @@ class DataFrameAggregateSuite extends TestData { checkAnswer( df1.groupBy($"key1").agg(count_distinct($"key2", $"key3")), - Seq(Row("a", 2), Row("x", 1)) - ) + Seq(Row("a", 2), Row("x", 1))) } test("zero count") { val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") checkAnswer( emptyTableData.agg(count($"a"), sum_distinct($"a")), // non-partial - Row(0, null) - ) + Row(0, null)) } test("zero stddev") { val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") checkAnswer( emptyTableData.agg(stddev($"a"), stddev_pop($"a"), stddev_samp($"a")), - Row(null, null, null) - ) + Row(null, null, null)) } test("zero sum") { @@ -915,8 +828,7 @@ class DataFrameAggregateSuite extends TestData { (8, 5, 11, "green", 99), (8, 4, 14, "blue", 99), (8, 3, 21, "red", 99), - (9, 9, 12, "orange", 99) - ).toDF("v1", "v2", "length", "color", "unused") + (9, 9, 12, "orange", 99)).toDF("v1", "v2", "length", "color", "unused") val result = df.groupBy(df.col("color")).agg(listagg(df.col("length"), ",")).collect() // result is unpredictable without within group @@ -924,13 +836,10 @@ class DataFrameAggregateSuite extends TestData { checkAnswer( df.groupBy(df.col("color")) - .agg( - listagg(df.col("length"), ",") - .withinGroup(df.col("length").asc) - ) + .agg(listagg(df.col("length"), ",") + .withinGroup(df.col("length").asc)) .sort($"color"), Seq(Row("blue", "14"), Row("green", "11,77"), Row("orange", "12"), Row("red", "21,24,35")), - sort = false - ) + 
sort = false) } } diff --git a/src/test/scala/com/snowflake/snowpark_test/DataFrameAliasSuite.scala b/src/test/scala/com/snowflake/snowpark_test/DataFrameAliasSuite.scala index ff31d479..5ca4ca9e 100644 --- a/src/test/scala/com/snowflake/snowpark_test/DataFrameAliasSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/DataFrameAliasSuite.scala @@ -58,27 +58,23 @@ class DataFrameAliasSuite extends TestData with BeforeAndAfterEach with EagerSes df1 .join(df2, $"id1" === $"id2") .select(df1.col("A.num1")), - Seq(Row(4), Row(5), Row(6)) - ) + Seq(Row(4), Row(5), Row(6))) checkAnswer( df1 .join(df2, $"id1" === $"id2") .select(df2.col("B.num2")), - Seq(Row(7), Row(8), Row(9)) - ) + Seq(Row(7), Row(8), Row(9))) checkAnswer( df1 .join(df2, $"id1" === $"id2") .select($"A.num1"), - Seq(Row(4), Row(5), Row(6)) - ) + Seq(Row(4), Row(5), Row(6))) checkAnswer( df1 .join(df2, $"id1" === $"id2") .select($"B.num2"), - Seq(Row(7), Row(8), Row(9)) - ) + Seq(Row(7), Row(8), Row(9))) } test("Test for alias with join with column renaming") { @@ -92,14 +88,12 @@ class DataFrameAliasSuite extends TestData with BeforeAndAfterEach with EagerSes df1 .join(df2, df1.col("id") === df2.col("id")) .select(df1.col("A.num")), - Seq(Row(4), Row(5), Row(6)) - ) + Seq(Row(4), Row(5), Row(6))) checkAnswer( df1 .join(df2, df1.col("id") === df2.col("id")) .select(df2.col("B.num")), - Seq(Row(7), Row(8), Row(9)) - ) + Seq(Row(7), Row(8), Row(9))) } test("Test for alias conflict") { @@ -110,8 +104,7 @@ class DataFrameAliasSuite extends TestData with BeforeAndAfterEach with EagerSes assertThrows[SnowparkClientException]( df1 .join(df2, df1.col("id") === df2.col("id")) - .select(df1.col("A.num")) - ) + .select(df1.col("A.num"))) } test("snow-1335123") { @@ -119,15 +112,13 @@ class DataFrameAliasSuite extends TestData with BeforeAndAfterEach with EagerSes "col_a", "col_b", "col_c", - "col_d" - ) + "col_d") val df2 = Seq((1, 2, 5, 6), (11, 12, 15, 16), (41, 12, 25, 26), (11, 42, 35, 36)).toDF( "col_a", 
"col_b", "col_e", - "col_f" - ) + "col_f") val df3 = df1 .alias("a") @@ -135,8 +126,7 @@ class DataFrameAliasSuite extends TestData with BeforeAndAfterEach with EagerSes df2.alias("b"), col("a.col_a") === col("b.col_a") && col("a.col_b") === col("b.col_b"), - "left" - ) + "left") .select("a.col_a", "a.col_b", "col_c", "col_d", "col_e", "col_f") checkAnswer( @@ -145,8 +135,6 @@ class DataFrameAliasSuite extends TestData with BeforeAndAfterEach with EagerSes Row(1, 2, 3, 4, 5, 6), Row(11, 12, 13, 14, 15, 16), Row(11, 32, 33, 34, null, null), - Row(21, 12, 23, 24, null, null) - ) - ) + Row(21, 12, 23, 24, null, null))) } } diff --git a/src/test/scala/com/snowflake/snowpark_test/DataFrameJoinSuite.scala b/src/test/scala/com/snowflake/snowpark_test/DataFrameJoinSuite.scala index 9e7fef44..09e780c6 100644 --- a/src/test/scala/com/snowflake/snowpark_test/DataFrameJoinSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/DataFrameJoinSuite.scala @@ -28,8 +28,7 @@ trait DataFrameJoinSuite extends SNTestBase { checkAnswer( df.join(df2, "int"), - Row(1, "1", "2") :: Row(2, "2", "3") :: Row(3, "3", "4") :: Nil - ) + Row(1, "1", "2") :: Row(2, "2", "3") :: Row(3, "3", "4") :: Nil) } test("join - join using multiple columns") { @@ -38,8 +37,7 @@ trait DataFrameJoinSuite extends SNTestBase { checkAnswer( df.join(df2, Seq("int", "int2")), - Row(1, 2, "1", "2") :: Row(2, 3, "2", "3") :: Row(3, 4, "3", "4") :: Nil - ) + Row(1, 2, "1", "2") :: Row(2, 3, "2", "3") :: Row(3, 4, "3", "4") :: Nil) } test("Full outer join followed by inner join") { @@ -66,8 +64,7 @@ trait DataFrameJoinSuite extends SNTestBase { checkAnswer( df.join(df2), - Seq(Row(1, 1, "test1"), Row(1, 2, "test2"), Row(2, 1, "test1"), Row(2, 2, "test2")) - ) + Seq(Row(1, 1, "test1"), Row(1, 2, "test2"), Row(2, 1, "test1"), Row(2, 2, "test2"))) } test("default inner join with using column") { @@ -92,8 +89,7 @@ trait DataFrameJoinSuite extends SNTestBase { val df2 = Seq(1, 2).map(i => (i, s"num$i")).toDF("num", 
"val") checkAnswer( df.join(df2, df("a") === df2("num")), - Seq(Row(1, "test1", 1, "num1"), Row(2, "test2", 2, "num2")) - ) + Seq(Row(1, "test1", 1, "num1"), Row(2, "test2", 2, "num2"))) } test("join with multiple conditions") { @@ -125,18 +121,15 @@ trait DataFrameJoinSuite extends SNTestBase { checkAnswer( df.join(df2, Seq("int", "str"), "left"), - Seq(Row(1, "1", 2, 3), Row(3, "3", 4, null)) - ) + Seq(Row(1, "1", 2, 3), Row(3, "3", 4, null))) checkAnswer( df.join(df2, Seq("int", "str"), "right"), - Seq(Row(1, "1", 2, 3), Row(5, "5", null, 6)) - ) + Seq(Row(1, "1", 2, 3), Row(5, "5", null, 6))) checkAnswer( df.join(df2, Seq("int", "str"), "outer"), - Seq(Row(1, "1", 2, 3), Row(3, "3", 4, null), Row(5, "5", null, 6)) - ) + Seq(Row(1, "1", 2, 3), Row(3, "3", 4, null), Row(5, "5", null, 6))) checkAnswer(df.join(df2, Seq("int", "str"), "left_semi"), Seq(Row(1, 2, "1"))) @@ -183,8 +176,7 @@ trait DataFrameJoinSuite extends SNTestBase { checkAnswer( df1.naturalJoin(df2, "outer"), - Seq(Row(1, "1", "1"), Row(3, "3", null), Row(4, null, "4")) - ) + Seq(Row(1, "1", "1"), Row(3, "3", null), Row(4, null, "4"))) } test("join - cross join") { @@ -194,14 +186,12 @@ trait DataFrameJoinSuite extends SNTestBase { checkAnswer( df1.crossJoin(df2), Row(1, "1", 2, "2") :: Row(1, "1", 4, "4") :: - Row(3, "3", 2, "2") :: Row(3, "3", 4, "4") :: Nil - ) + Row(3, "3", 2, "2") :: Row(3, "3", 4, "4") :: Nil) checkAnswer( df2.crossJoin(df1), Row(2, "2", 1, "1") :: Row(2, "2", 3, "3") :: - Row(4, "4", 1, "1") :: Row(4, "4", 3, "3") :: Nil - ) + Row(4, "4", 1, "1") :: Row(4, "4", 3, "3") :: Nil) } test("join -- ambiguous columns with specified sources") { @@ -210,8 +200,7 @@ trait DataFrameJoinSuite extends SNTestBase { checkAnswer(df.join(df2, df("a") === df2("a")), Row(1, 1, "test1") :: Row(2, 2, "test2") :: Nil) checkAnswer( df.join(df2, df("a") === df2("a")).select(df("a") * df2("a"), 'b), - Row(1, "test1") :: Row(4, "test2") :: Nil - ) + Row(1, "test1") :: Row(4, "test2") :: Nil) } 
test("join -- ambiguous columns without specified sources") { @@ -240,8 +229,7 @@ trait DataFrameJoinSuite extends SNTestBase { lhs .join(rhs, lhs("intcol") === rhs("intcol")) .select(lhs("intcol") + rhs("intcol"), lhs("negcol"), rhs("negcol"), 'lhscol, 'rhscol), - Row(2, -1, -10, "one", "one") :: Row(4, -2, -20, "two", "two") :: Nil - ) + Row(2, -1, -10, "one", "one") :: Row(4, -2, -20, "two", "two") :: Nil) } test("Semi joins with absent columns") { @@ -266,34 +254,29 @@ trait DataFrameJoinSuite extends SNTestBase { checkAnswer( lhs.join(rhs, lhs("intcol") === rhs("intcol"), "leftsemi").select('intcol), - Row(1) :: Row(2) :: Nil - ) + Row(1) :: Row(2) :: Nil) checkAnswer( lhs.join(rhs, lhs("intcol") === rhs("intcol"), "leftsemi").select(lhs("intcol")), - Row(1) :: Row(2) :: Nil - ) + Row(1) :: Row(2) :: Nil) checkAnswer( lhs .join(rhs, lhs("intcol") === rhs("intcol") && lhs("negcol") === rhs("negcol"), "leftsemi") .select(lhs("intcol")), - Nil - ) + Nil) checkAnswer(lhs.join(rhs, lhs("intcol") === rhs("intcol"), "leftanti").select('intcol), Nil) checkAnswer( lhs.join(rhs, lhs("intcol") === rhs("intcol"), "leftanti").select(lhs("intcol")), - Nil - ) + Nil) checkAnswer( lhs .join(rhs, lhs("intcol") === rhs("intcol") && lhs("negcol") === rhs("negcol"), "leftanti") .select(lhs("intcol")), - Row(1) :: Row(2) :: Nil - ) + Row(1) :: Row(2) :: Nil) } test("Using joins") { @@ -303,12 +286,10 @@ trait DataFrameJoinSuite extends SNTestBase { Seq("inner", "leftouter", "rightouter", "full_outer").foreach { joinType => checkAnswer( lhs.join(rhs, Seq("intcol"), joinType).select("*"), - Row(1, -1, "one", -10, "one") :: Row(2, -2, "two", -20, "two") :: Nil - ) + Row(1, -1, "one", -10, "one") :: Row(2, -2, "two", -20, "two") :: Nil) checkAnswer( lhs.join(rhs, Seq("intcol"), joinType), - Row(1, -1, "one", -10, "one") :: Row(2, -2, "two", -20, "two") :: Nil - ) + Row(1, -1, "one", -10, "one") :: Row(2, -2, "two", -20, "two") :: Nil) val ex2 = intercept[SnowparkClientException] { 
lhs.join(rhs, Seq("intcol"), joinType).select('negcol).collect() @@ -319,8 +300,7 @@ trait DataFrameJoinSuite extends SNTestBase { checkAnswer(lhs.join(rhs, Seq("intcol"), joinType).select('intcol), Row(1) :: Row(2) :: Nil) checkAnswer( lhs.join(rhs, Seq("intcol"), joinType).select(lhs("negcol"), rhs("negcol")), - Row(-1, -10) :: Row(-2, -20) :: Nil - ) + Row(-1, -10) :: Row(-2, -20) :: Nil) } } @@ -332,23 +312,20 @@ trait DataFrameJoinSuite extends SNTestBase { lhs .join(rhs, lhs("intcol") === rhs("intcol")) .select(lhs(""""INTCOL""""), rhs("intcol"), 'doublecol, col(""""DoubleCol"""")), - Row(1, 1, 1.0, 2.0) :: Nil - ) + Row(1, 1, 1.0, 2.0) :: Nil) checkAnswer( lhs .join(rhs, lhs("doublecol") === rhs("\"DoubleCol\"")) .select(lhs(""""INTCOL""""), rhs("intcol"), 'doublecol, rhs(""""DoubleCol"""")), - Nil - ) + Nil) // Below LHS and RHS are swapped but we still default to using the column name as is. checkAnswer( lhs .join(rhs, col("doublecol") === col("\"DoubleCol\"")) .select(lhs(""""INTCOL""""), rhs("intcol"), 'doublecol, col(""""DoubleCol"""")), - Nil - ) + Nil) var ex = intercept[SnowparkClientException] { lhs.join(rhs, col("intcol") === rhs(""""INTCOL"""")).collect() @@ -365,8 +342,7 @@ trait DataFrameJoinSuite extends SNTestBase { .join(rhs, lhs("intcol") === rhs("intcol")) .select((lhs("negcol") + rhs("negcol")) as "newCol", lhs("intcol"), rhs("intcol")) .select(lhs("intcol") + rhs("intcol"), 'newCol), - Row(2, -11) :: Row(4, -22) :: Nil - ) + Row(2, -11) :: Row(4, -22) :: Nil) } test("join - sql as the backing dataframe") { @@ -376,25 +352,21 @@ trait DataFrameJoinSuite extends SNTestBase { val df = session.sql(s"select * from $tableName1 where int2 < 10") val df2 = session.sql( s"select 1 as INT, 3 as INT2, '1' as STR " - + "UNION select 5 as INT, 6 as INT2, '5' as STR" - ) + + "UNION select 5 as INT, 6 as INT2, '5' as STR") checkAnswer(df.join(df2, Seq("int", "str"), "inner"), Seq(Row(1, "1", 2, 3))) checkAnswer( df.join(df2, Seq("int", "str"), "left"), 
- Seq(Row(1, "1", 2, 3), Row(3, "3", 4, null)) - ) + Seq(Row(1, "1", 2, 3), Row(3, "3", 4, null))) checkAnswer( df.join(df2, Seq("int", "str"), "right"), - Seq(Row(1, "1", 2, 3), Row(5, "5", null, 6)) - ) + Seq(Row(1, "1", 2, 3), Row(5, "5", null, 6))) checkAnswer( df.join(df2, Seq("int", "str"), "outer"), - Seq(Row(1, "1", 2, 3), Row(3, "3", 4, null), Row(5, "5", null, 6)) - ) + Seq(Row(1, "1", 2, 3), Row(3, "3", 4, null), Row(5, "5", null, 6))) val res = df.join(df2, Seq("int", "str"), "left_semi").collect checkAnswer(df.join(df2, Seq("int", "str"), "left_semi"), Seq(Row(1, 2, "1"))) @@ -440,21 +412,18 @@ trait DataFrameJoinSuite extends SNTestBase { // "left" self join checkAnswer( df.join(clonedDF, df("c1") === clonedDF("c2"), "left"), - Seq(Row(1, 2, null, null), Row(2, 3, 1, 2)) - ) + Seq(Row(1, 2, null, null), Row(2, 3, 1, 2))) // "right" self join checkAnswer( df.join(clonedDF, df("c1") === clonedDF("c2"), "right"), - Seq(Row(2, 3, 1, 2), Row(null, null, 2, 3)) - ) + Seq(Row(2, 3, 1, 2), Row(null, null, 2, 3))) // "outer" self join checkAnswer( df.join(clonedDF, df("c1") === clonedDF("c2"), "outer"), Seq(Row(2, 3, 1, 2), Row(null, null, 2, 3), Row(1, 2, null, null)), - false - ) + false) } test("test natural/cross join") { @@ -471,12 +440,10 @@ trait DataFrameJoinSuite extends SNTestBase { // "cross join" supports self join. 
checkAnswer( df.crossJoin(df2), - Seq(Row(1, 2, 1, 2), Row(1, 2, 2, 3), Row(2, 3, 1, 2), Row(2, 3, 2, 3)) - ) + Seq(Row(1, 2, 1, 2), Row(1, 2, 2, 3), Row(2, 3, 1, 2), Row(2, 3, 2, 3))) checkAnswer( df.crossJoin(clonedDF), - Seq(Row(1, 2, 1, 2), Row(1, 2, 2, 3), Row(2, 3, 1, 2), Row(2, 3, 2, 3)) - ) + Seq(Row(1, 2, 1, 2), Row(1, 2, 2, 3), Row(2, 3, 1, 2), Row(2, 3, 2, 3))) } test("clone with join DataFrame") { @@ -579,8 +546,7 @@ trait DataFrameJoinSuite extends SNTestBase { .join(df2, df1("c") === df2("c")) .drop(df1("b"), df2("b"), df1("c")) .withColumn("newColumn", df1("a") + df2("a")), - Seq(Row(1, 3, true, 4), Row(2, 4, false, 6)) - ) + Seq(Row(1, 3, true, 4), Row(2, 4, false, 6))) } finally { dropTable(tableName1) dropTable(tableName2) @@ -595,8 +561,7 @@ trait DataFrameJoinSuite extends SNTestBase { df1.join(df2, df1("id") === df2("id"), "left_outer").filter(is_null(df2("count"))), Row(2, 0, null, null) :: Row(3, 0, null, null) :: - Row(4, 0, null, null) :: Nil - ) + Row(4, 0, null, null) :: Nil) // Coalesce data using non-nullable columns in input tables val df3 = Seq((1, 1)).toDF("a", "b") @@ -605,8 +570,7 @@ trait DataFrameJoinSuite extends SNTestBase { df3 .join(df4, df3("a") === df4("a"), "outer") .select(coalesce(df3("a"), df3("b")), coalesce(df4("a"), df4("b"))), - Row(1, null) :: Row(null, 2) :: Nil - ) + Row(1, null) :: Row(null, 2) :: Nil) } test("SN: join - outer join conversion") { @@ -642,8 +606,7 @@ trait DataFrameJoinSuite extends SNTestBase { test( "Don't throw Analysis Exception in CheckCartesianProduct when join condition " + - "is false or null" - ) { + "is false or null") { val df = session.range(10).toDF("id") val dfNull = session.range(10).select(lit(null).as("b")) df.join(dfNull, $"id" === $"b", "left").collect() @@ -660,13 +623,11 @@ trait DataFrameJoinSuite extends SNTestBase { runQuery( s"create or replace table $tableTrips " + "(starttime timestamp, start_station_id int, end_station_id int)", - session - ) + session) runQuery( 
s"create or replace table $tableStations " + "(station_id int, station_name string)", - session - ) + session) val df_trips = session.table(tableTrips) val df_start_stations = session.table(tableStations) @@ -680,8 +641,7 @@ trait DataFrameJoinSuite extends SNTestBase { .select( df_start_stations("station_name"), df_end_stations("station_name"), - df_trips("starttime") - ) + df_trips("starttime")) .collect() } finally { @@ -697,13 +657,11 @@ trait DataFrameJoinSuite extends SNTestBase { runQuery( s"create or replace table $tableTrips " + "(starttime timestamp, \"startid\" int, \"end+station+id\" int)", - session - ) + session) runQuery( s"create or replace table $tableStations " + "(\"station^id\" int, \"station%name\" string)", - session - ) + session) val df_trips = session.table(tableTrips) val df_start_stations = session.table(tableStations) @@ -717,8 +675,7 @@ trait DataFrameJoinSuite extends SNTestBase { .select( df_start_stations("station%name"), df_end_stations("station%name"), - df_trips("starttime") - ) + df_trips("starttime")) .collect() } finally { @@ -754,8 +711,7 @@ trait DataFrameJoinSuite extends SNTestBase { |------------------------- ||1 |2 |3 |4 | |------------------------- - |""".stripMargin - ) + |""".stripMargin) assert( getShowString(df.select(df("*")), 10) == """------------------------- @@ -763,8 +719,7 @@ trait DataFrameJoinSuite extends SNTestBase { |------------------------- ||1 |2 |3 |4 | |------------------------- - |""".stripMargin - ) + |""".stripMargin) assert( getShowString(df.select(dfLeft("*"), dfRight("*")), 10) == """------------------------- @@ -772,8 +727,7 @@ trait DataFrameJoinSuite extends SNTestBase { |------------------------- ||1 |2 |3 |4 | |------------------------- - |""".stripMargin - ) + |""".stripMargin) assert( getShowString(df.select(dfRight("*"), dfLeft("*")), 10) == """------------------------- @@ -781,8 +735,7 @@ trait DataFrameJoinSuite extends SNTestBase { |------------------------- ||3 |4 |1 |2 | 
|------------------------- - |""".stripMargin - ) + |""".stripMargin) } test("select left/right on join result") { @@ -798,8 +751,7 @@ trait DataFrameJoinSuite extends SNTestBase { |------------- ||1 |2 | |------------- - |""".stripMargin - ) + |""".stripMargin) assert( getShowString(df.select(dfRight("*")), 10) == """------------- @@ -807,8 +759,7 @@ trait DataFrameJoinSuite extends SNTestBase { |------------- ||3 |4 | |------------- - |""".stripMargin - ) + |""".stripMargin) } test("select left/right combination on join result") { @@ -824,8 +775,7 @@ trait DataFrameJoinSuite extends SNTestBase { |------------------- ||1 |2 |3 | |------------------- - |""".stripMargin - ) + |""".stripMargin) // Select left(*) and left("a") assert( getShowString(df.select(dfLeft("*"), dfLeft("a").as("l_a")), 10) == @@ -834,8 +784,7 @@ trait DataFrameJoinSuite extends SNTestBase { |--------------------- ||1 |2 |1 | |--------------------- - |""".stripMargin - ) + |""".stripMargin) // Select right(*) and right("c") assert( getShowString(df.select(dfRight("*"), dfRight("c").as("R_C")), 10) == @@ -844,8 +793,7 @@ trait DataFrameJoinSuite extends SNTestBase { |--------------------- ||3 |4 |3 | |--------------------- - |""".stripMargin - ) + |""".stripMargin) // Select right(*) and left("a") assert( getShowString(df.select(dfRight("*"), dfLeft("a")), 10) == @@ -854,8 +802,7 @@ trait DataFrameJoinSuite extends SNTestBase { |------------------- ||3 |4 |1 | |------------------- - |""".stripMargin - ) + |""".stripMargin) df.select(dfRight("*"), dfRight("c")).show() } diff --git a/src/test/scala/com/snowflake/snowpark_test/DataFrameNonStoredProcSuite.scala b/src/test/scala/com/snowflake/snowpark_test/DataFrameNonStoredProcSuite.scala index 99ceab7f..2d83572b 100644 --- a/src/test/scala/com/snowflake/snowpark_test/DataFrameNonStoredProcSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/DataFrameNonStoredProcSuite.scala @@ -14,8 +14,7 @@ class DataFrameNonStoredProcSuite extends 
TestData { ||1 |2 |2 |2 |2 | ||2 |2 |2 |2 |2 | |--------------------------------------------------- - |""".stripMargin - ) + |""".stripMargin) assert( getShowString(monthlySales.stat.crosstab("month", "empid").sort(col("month")), 10) == @@ -27,8 +26,7 @@ class DataFrameNonStoredProcSuite extends TestData { ||JAN |2 |2 | ||MAR |2 |2 | |------------------------------------------------------------------- - |""".stripMargin - ) + |""".stripMargin) assert( getShowString(date1.stat.crosstab("a", "b").sort(col("a")), 10) == @@ -38,8 +36,7 @@ class DataFrameNonStoredProcSuite extends TestData { ||2010-12-01 |0 |1 | ||2020-08-01 |1 |0 | |---------------------------------------------------------------------- - |""".stripMargin - ) + |""".stripMargin) assert( getShowString(date1.stat.crosstab("b", "a").sort(col("b")), 10) == @@ -49,8 +46,7 @@ class DataFrameNonStoredProcSuite extends TestData { ||1 |1 |0 | ||2 |0 |1 | |----------------------------------------------------------- - |""".stripMargin - ) + |""".stripMargin) assert( getShowString(string7.stat.crosstab("a", "b").sort(col("a")), 10) == @@ -60,8 +56,7 @@ class DataFrameNonStoredProcSuite extends TestData { ||NULL |0 |1 | ||str |1 |0 | |---------------------------------------------------------------- - |""".stripMargin - ) + |""".stripMargin) assert( getShowString(string7.stat.crosstab("b", "a").sort(col("b")), 10) == @@ -71,8 +66,7 @@ class DataFrameNonStoredProcSuite extends TestData { ||1 |1 |0 | ||2 |0 |0 | |-------------------------- - |""".stripMargin - ) + |""".stripMargin) } test("df.stat.pivot") { @@ -80,15 +74,13 @@ class DataFrameNonStoredProcSuite extends TestData { testDataframeStatPivot(), "ENABLE_PIVOT_VIEW_WITH_OBJECT_AGG", "disable", - skipIfParamNotExist = true - ) + skipIfParamNotExist = true) testWithAlteredSessionParameter( testDataframeStatPivot(), "ENABLE_PIVOT_VIEW_WITH_OBJECT_AGG", "enable", - skipIfParamNotExist = true - ) + skipIfParamNotExist = true) } test("ERROR_ON_NONDETERMINISTIC_UPDATE 
= true") { diff --git a/src/test/scala/com/snowflake/snowpark_test/DataFrameReaderSuite.scala b/src/test/scala/com/snowflake/snowpark_test/DataFrameReaderSuite.scala index 32534963..1735e628 100644 --- a/src/test/scala/com/snowflake/snowpark_test/DataFrameReaderSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/DataFrameReaderSuite.scala @@ -12,8 +12,7 @@ class DataFrameReaderSuite extends SNTestBase { val tmpStageName: String = randomStageName() val tmpStageName2: String = randomStageName() private val userSchema: StructType = StructType( - Seq(StructField("a", IntegerType), StructField("b", StringType), StructField("c", DoubleType)) - ) + Seq(StructField("a", IntegerType), StructField("b", StringType), StructField("c", DoubleType))) override def beforeAll(): Unit = { super.beforeAll() @@ -68,9 +67,7 @@ class DataFrameReaderSuite extends SNTestBase { Seq( StructField("a", IntegerType), StructField("b", IntegerType), - StructField("c", IntegerType) - ) - ) + StructField("c", IntegerType))) val df2 = reader().schema(incorrectSchema).csv(testFileOnStage) assertThrows[SnowflakeSQLException](df2.collect()) }) @@ -86,9 +83,7 @@ class DataFrameReaderSuite extends SNTestBase { StructField("a", IntegerType), StructField("b", StringType), StructField("c", IntegerType), - StructField("d", IntegerType) - ) - ) + StructField("d", IntegerType))) val df = session.read.schema(incorrectSchema2).csv(testFileOnStage) assert(df.collect() sameElements Array[Row](Row(1, "one", 1, null), Row(2, "two", 2, null))) } @@ -100,9 +95,7 @@ class DataFrameReaderSuite extends SNTestBase { StructField("a", IntegerType), StructField("b", StringType), StructField("c", IntegerType), - StructField("d", IntegerType) - ) - ) + StructField("d", IntegerType))) val df = session.read.option("purge", false).schema(incorrectSchema2).csv(testFileOnStage) // throw exception from COPY assertThrows[SnowflakeSQLException](df.collect()) @@ -125,9 +118,7 @@ class DataFrameReaderSuite extends SNTestBase { 
Row(1, "one", 1.2), Row(2, "two", 2.2), Row(1, "one", 1.2), - Row(2, "two", 2.2) - ) - ) + Row(2, "two", 2.2))) // test for union between two stages val testFileOnStage2 = s"@$tmpStageName2/$testFileCsv" @@ -140,9 +131,7 @@ class DataFrameReaderSuite extends SNTestBase { Row(1, "one", 1.2), Row(2, "two", 2.2), Row(1, "one", 1.2), - Row(2, "two", 2.2) - ) - ) + Row(2, "two", 2.2))) } testReadFile("read csv with formatTypeOptions")(reader => { @@ -175,9 +164,7 @@ class DataFrameReaderSuite extends SNTestBase { Row(1, "one", 1.2), Row(2, "two", 2.2), Row(1, "one", 1.2), - Row(2, "two", 2.2) - ) - ) + Row(2, "two", 2.2))) val df6 = df1.union(df4) assert(df6.collect() sameElements Array[Row](Row(1, "one", 1.2), Row(2, "two", 2.2))) @@ -197,8 +184,7 @@ class DataFrameReaderSuite extends SNTestBase { .csv(s"@$dataFilesStage/") checkAnswer( df, - Seq(Row(1, "one", 1.2), Row(2, "two", 2.2), Row(3, "three", 3.3), Row(4, "four", 4.4)) - ) + Seq(Row(1, "one", 1.2), Row(2, "two", 2.2), Row(3, "three", 3.3), Row(4, "four", 4.4))) } finally { runQuery(s"DROP STAGE IF EXISTS $dataFilesStage", session) } @@ -226,8 +212,7 @@ class DataFrameReaderSuite extends SNTestBase { .sql( s"copy into $path from" + s" ( select * from $tmpTable) file_format=(format_name='$formatName'" + - s" compression='$ctype')" - ) + s" compression='$ctype')") .collect() // Read the data @@ -236,8 +221,7 @@ class DataFrameReaderSuite extends SNTestBase { .option("COMPRESSION", ctype) .schema(userSchema) .csv(path), - result - ) + result) }) } finally { runQuery(s"drop file format $formatName", session) @@ -250,9 +234,7 @@ class DataFrameReaderSuite extends SNTestBase { StructField("a", IntegerType), StructField("b", StringType), StructField("c", DoubleType), - StructField("d", IntegerType) - ) - ) + StructField("d", IntegerType))) val testFile = s"@$tmpStageName/$testFileCsvQuotes" val df1 = reader() .schema(schema1) @@ -273,9 +255,7 @@ class DataFrameReaderSuite extends SNTestBase { StructField("a", 
IntegerType), StructField("b", StringType), StructField("c", DoubleType), - StructField("d", StringType) - ) - ) + StructField("d", StringType))) val df3 = reader().schema(schema2).csv(testFile) val res = df3.collect() checkAnswer(df3.select("d"), Seq(Row("\"1\""), Row("\"2\""))) @@ -287,14 +267,12 @@ class DataFrameReaderSuite extends SNTestBase { checkAnswer( df1, - Seq(Row("{\n \"color\": \"Red\",\n \"fruit\": \"Apple\",\n \"size\": \"Large\"\n}")) - ) + Seq(Row("{\n \"color\": \"Red\",\n \"fruit\": \"Apple\",\n \"size\": \"Large\"\n}"))) // query test checkAnswer( df1.where(sqlExpr("$1:color") === "Red"), - Seq(Row("{\n \"color\": \"Red\",\n \"fruit\": \"Apple\",\n \"size\": \"Large\"\n}")) - ) + Seq(Row("{\n \"color\": \"Red\",\n \"fruit\": \"Apple\",\n \"size\": \"Large\"\n}"))) // assert user cannot input a schema to read json assertThrows[IllegalArgumentException](reader().schema(userSchema).json(jsonPath)) @@ -303,8 +281,7 @@ class DataFrameReaderSuite extends SNTestBase { val df2 = reader().option("FILE_EXTENSION", "json").json(jsonPath) checkAnswer( df2, - Seq(Row("{\n \"color\": \"Red\",\n \"fruit\": \"Apple\",\n \"size\": \"Large\"\n}")) - ) + Seq(Row("{\n \"color\": \"Red\",\n \"fruit\": \"Apple\",\n \"size\": \"Large\"\n}"))) }) testReadFile("read avro with no schema")(reader => { @@ -314,16 +291,13 @@ class DataFrameReaderSuite extends SNTestBase { df1, Seq( Row("{\n \"num\": 1,\n \"str\": \"str1\"\n}"), - Row("{\n \"num\": 2,\n \"str\": \"str2\"\n}") - ), - sort = false - ) + Row("{\n \"num\": 2,\n \"str\": \"str2\"\n}")), + sort = false) // query test checkAnswer( df1.where(sqlExpr("$1:num") > 1), - Seq(Row("{\n \"num\": 2,\n \"str\": \"str2\"\n}")) - ) + Seq(Row("{\n \"num\": 2,\n \"str\": \"str2\"\n}"))) // assert user cannot input a schema to read avro assertThrows[IllegalArgumentException](session.read.schema(userSchema).avro(avroPath)) @@ -334,10 +308,8 @@ class DataFrameReaderSuite extends SNTestBase { df2, Seq( Row("{\n \"num\": 1,\n 
\"str\": \"str1\"\n}"), - Row("{\n \"num\": 2,\n \"str\": \"str2\"\n}") - ), - sort = false - ) + Row("{\n \"num\": 2,\n \"str\": \"str2\"\n}")), + sort = false) }) testReadFile("Test for all parquet compression types")(reader => { @@ -354,8 +326,7 @@ class DataFrameReaderSuite extends SNTestBase { .sql( s"copy into @$tmpStageName/$ctype/ from" + s" ( select * from $tmpTable) file_format=(format_name='$formatName'" + - s" compression='$ctype') overwrite = true" - ) + s" compression='$ctype') overwrite = true") .collect() // Read the data @@ -373,17 +344,14 @@ class DataFrameReaderSuite extends SNTestBase { df1, Seq( Row("{\n \"num\": 1,\n \"str\": \"str1\"\n}"), - Row("{\n \"num\": 2,\n \"str\": \"str2\"\n}") - ), - sort = false - ) + Row("{\n \"num\": 2,\n \"str\": \"str2\"\n}")), + sort = false) // query test checkAnswer( df1.where(sqlExpr("$1:num") > 1), Seq(Row("{\n \"num\": 2,\n \"str\": \"str2\"\n}")), - sort = false - ) + sort = false) // assert user cannot input a schema to read parquet assertThrows[IllegalArgumentException](session.read.schema(userSchema).parquet(path)) @@ -394,10 +362,8 @@ class DataFrameReaderSuite extends SNTestBase { df2, Seq( Row("{\n \"num\": 1,\n \"str\": \"str1\"\n}"), - Row("{\n \"num\": 2,\n \"str\": \"str2\"\n}") - ), - sort = false - ) + Row("{\n \"num\": 2,\n \"str\": \"str2\"\n}")), + sort = false) }) testReadFile("read orc with no schema")(reader => { @@ -407,17 +373,14 @@ class DataFrameReaderSuite extends SNTestBase { df1, Seq( Row("{\n \"num\": 1,\n \"str\": \"str1\"\n}"), - Row("{\n \"num\": 2,\n \"str\": \"str2\"\n}") - ), - sort = false - ) + Row("{\n \"num\": 2,\n \"str\": \"str2\"\n}")), + sort = false) // query test checkAnswer( df1.where(sqlExpr("$1:num") > 1), Seq(Row("{\n \"num\": 2,\n \"str\": \"str2\"\n}")), - sort = false - ) + sort = false) // assert user cannot input a schema to read avro assertThrows[IllegalArgumentException](session.read.schema(userSchema).orc(path)) @@ -428,10 +391,8 @@ class 
DataFrameReaderSuite extends SNTestBase { df2, Seq( Row("{\n \"num\": 1,\n \"str\": \"str1\"\n}"), - Row("{\n \"num\": 2,\n \"str\": \"str2\"\n}") - ), - sort = false - ) + Row("{\n \"num\": 2,\n \"str\": \"str2\"\n}")), + sort = false) }) testReadFile("read xml with no schema")(reader => { @@ -441,18 +402,15 @@ class DataFrameReaderSuite extends SNTestBase { df1, Seq( Row("\n 1\n str1\n"), - Row("\n 2\n str2\n") - ), - sort = false - ) + Row("\n 2\n str2\n")), + sort = false) // query test checkAnswer( df1 .where(sqlExpr("xmlget($1, 'num', 0):\"$\"") > 1), Seq(Row("\n 2\n str2\n")), - sort = false - ) + sort = false) // assert user cannot input a schema to read avro assertThrows[IllegalArgumentException](session.read.schema(userSchema).xml(path)) @@ -463,10 +421,8 @@ class DataFrameReaderSuite extends SNTestBase { df2, Seq( Row("\n 1\n str1\n"), - Row("\n 2\n str2\n") - ), - sort = false - ) + Row("\n 2\n str2\n")), + sort = false) }) test("read file on_error = continue on CSV") { @@ -479,8 +435,7 @@ class DataFrameReaderSuite extends SNTestBase { .option("on_error", "continue") .option("COMPRESSION", "none") .csv(brokenFile), - Seq(Row(1, "one", 1.1), Row(3, "three", 3.3)) - ) + Seq(Row(1, "one", 1.1), Row(3, "three", 3.3))) } @@ -493,8 +448,7 @@ class DataFrameReaderSuite extends SNTestBase { .option("on_error", "continue") .option("COMPRESSION", "none") .avro(brokenFile), - Seq.empty - ) + Seq.empty) } test("SELECT and COPY on non CSV format have same result schema") { @@ -530,8 +484,7 @@ class DataFrameReaderSuite extends SNTestBase { .option("COMPRESSION", "none") .option("pattern", ".*CSV[.]csv") .csv(s"@$tmpStageName") - .count() == 4 - ) + .count() == 4) }) testReadFile("table function on csv dataframe reader test")(reader => { @@ -542,13 +495,11 @@ class DataFrameReaderSuite extends SNTestBase { session .tableFunction(TableFunction("split_to_table"), Seq(df1("b"), lit(" "))) .select("VALUE"), - Seq(Row("one"), Row("two")) - ) + Seq(Row("one"), 
Row("two"))) }) def testReadFile(testName: String, testTags: Tag*)( - thunk: (() => DataFrameReader) => Unit - ): Unit = { + thunk: (() => DataFrameReader) => Unit): Unit = { // test select test(testName + " - SELECT", testTags: _*) { thunk(() => session.read) diff --git a/src/test/scala/com/snowflake/snowpark_test/DataFrameSetOperationsSuite.scala b/src/test/scala/com/snowflake/snowpark_test/DataFrameSetOperationsSuite.scala index 17bce597..7f4c5a04 100644 --- a/src/test/scala/com/snowflake/snowpark_test/DataFrameSetOperationsSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/DataFrameSetOperationsSuite.scala @@ -37,8 +37,7 @@ class DataFrameSetOperationsSuite extends TestData { check( lit(null).cast(IntegerType), $"c".is_null, - Seq(Row(1, 1, null, 100), Row(1, 1, null, 100)) - ) + Seq(Row(1, 1, null, 100), Row(1, 1, null, 100))) check(lit(null).cast(IntegerType), $"c".is_not_null, Seq()) check(lit(2).cast(IntegerType), $"c".is_null, Seq()) check(lit(2).cast(IntegerType), $"c".is_not_null, Seq(Row(1, 1, 2, 100), Row(1, 1, 2, 100))) @@ -52,8 +51,7 @@ class DataFrameSetOperationsSuite extends TestData { Row(1, "a") :: Row(2, "b") :: Row(3, "c") :: - Row(4, "d") :: Nil - ) + Row(4, "d") :: Nil) checkAnswer(lowerCaseData.except(lowerCaseData), Nil) checkAnswer(upperCaseData.except(upperCaseData), Nil) @@ -71,8 +69,7 @@ class DataFrameSetOperationsSuite extends TestData { df.except(df.filter(lit(0) === 1)), Row("id", 1) :: Row("id1", 1) :: - Row("id1", 2) :: Nil - ) + Row("id1", 2) :: Nil) // check if the empty set on the left side works checkAnswer(allNulls.filter(lit(0) === 1).except(allNulls), Nil) @@ -249,8 +246,7 @@ class DataFrameSetOperationsSuite extends TestData { test("SN - Performing set operations that combine non-scala native types") { val dates = Seq( (new Date(0), BigDecimal.valueOf(1), new Timestamp(2)), - (new Date(3), BigDecimal.valueOf(4), new Timestamp(5)) - ).toDF("date", "decimal", "timestamp") + (new Date(3), BigDecimal.valueOf(4), new 
Timestamp(5))).toDF("date", "decimal", "timestamp") val widenTypedRows = Seq((new Timestamp(2), 10.5d, (new Timestamp(10)).toString)) @@ -301,8 +297,7 @@ class DataFrameSetOperationsSuite extends TestData { Row(1, "a") :: Row(2, "b") :: Row(3, "c") :: - Row(4, "d") :: Nil - ) + Row(4, "d") :: Nil) checkAnswer(lowerCaseData.intersect(upperCaseData), Nil) // check null equality @@ -311,8 +306,7 @@ class DataFrameSetOperationsSuite extends TestData { Row(1) :: Row(2) :: Row(3) :: - Row(null) :: Nil - ) + Row(null) :: Nil) // check if values are de-duplicated checkAnswer(allNulls.intersect(allNulls), Row(null) :: Nil) @@ -323,8 +317,7 @@ class DataFrameSetOperationsSuite extends TestData { df.intersect(df), Row("id", 1) :: Row("id1", 1) :: - Row("id1", 2) :: Nil - ) + Row("id1", 2) :: Nil) } test("Project should not be pushed down through Intersect or Except") { @@ -373,8 +366,7 @@ class DataFrameSetOperationsSuite extends TestData { checkAnswer( df1.union(df2).intersect(df2.union(df3)).union(df3), df2.union(df3).collect(), - sort = false - ) + sort = false) checkAnswer(df1.union(df2).except(df2.union(df3).intersect(df1.union(df2))), df1.collect()) } diff --git a/src/test/scala/com/snowflake/snowpark_test/DataFrameSuite.scala b/src/test/scala/com/snowflake/snowpark_test/DataFrameSuite.scala index ab2051ab..fed14421 100644 --- a/src/test/scala/com/snowflake/snowpark_test/DataFrameSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/DataFrameSuite.scala @@ -40,8 +40,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { checkAnswer( df.sort(col("b").asc_nulls_last), Seq(Row(2, "NotNull"), Row(1, null), Row(3, null)), - false - ) + false) } test("Project null values") { @@ -82,8 +81,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { ||1 |NULL | ||2 |N... 
| |-------------- - |""".stripMargin - ) + |""".stripMargin) } test("show with null data") { @@ -99,8 +97,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { ||1 |NULL | ||2 |NotNull | |----------------- - |""".stripMargin - ) + |""".stripMargin) } test("show multi-lines row") { @@ -117,8 +114,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { || |one more line | || |last line | |------------------------------- - |""".stripMargin - ) + |""".stripMargin) } test("show") { @@ -133,8 +129,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { ||1 |true |a | ||2 |false |b | |-------------------------- - |""".stripMargin - ) + |""".stripMargin) session.sql("show tables").show() @@ -147,8 +142,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { |------------------------------------------------------ ||Drop statement executed successfully (TEST_TABL... | |------------------------------------------------------ - |""".stripMargin - ) + |""".stripMargin) } test("cacheResult") { @@ -162,8 +156,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { testCacheResult(), "snowpark_use_scoped_temp_objects", "true", - skipIfParamNotExist = true - ) + skipIfParamNotExist = true) } private def testCacheResult(): Unit = { @@ -256,8 +249,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { .select(""""name"""") .filter(col(""""name"""") === tableName) .collect() - .length == 1 - ) + .length == 1) } finally { runQuery(s"drop table if exists $tableName", session) } @@ -275,8 +267,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { .select(""""name"""") .filter(col(""""name"""") === tableName) .collect() - .length == 2 - ) + .length == 2) } finally { runQuery(s"drop table if exists $tableName", session) } @@ -296,8 +287,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { checkAnswer(double3.na.drop(1, Seq("a")), Seq(Row(1.0, 1), Row(4.0, null))) checkAnswer( 
double3.na.drop(1, Seq("a", "b")), - Seq(Row(1.0, 1), Row(4.0, null), Row(Double.NaN, 2), Row(null, 3)) - ) + Seq(Row(1.0, 1), Row(4.0, null), Row(Double.NaN, 2), Row(null, 3))) assert(double3.na.drop(0, Seq("a")).count() == 6) assert(double3.na.drop(3, Seq("a", "b")).count() == 0) assert(double3.na.drop(1, Seq()).count() == 6) @@ -313,10 +303,8 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { Row(12.3, 3, false, "f"), Row(4.0, 11, false, "d"), Row(12.3, 11, false, "f"), - Row(12.3, 11, false, "f") - ), - sort = false - ) + Row(12.3, 11, false, "f")), + sort = false) checkAnswer( nullData3.na.fill(Map("flo" -> 22.3f, "int" -> 22L, "boo" -> false, "str" -> "f")), Seq( @@ -325,38 +313,30 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { Row(22.3, 3, false, "f"), Row(4.0, 22, false, "d"), Row(22.3, 22, false, "f"), - Row(22.3, 22, false, "f") - ), - sort = false - ) + Row(22.3, 22, false, "f")), + sort = false) checkAnswer( nullData3.na.fill( - Map("flo" -> 12.3, "int" -> 33.asInstanceOf[Short], "boo" -> false, "str" -> "f") - ), + Map("flo" -> 12.3, "int" -> 33.asInstanceOf[Short], "boo" -> false, "str" -> "f")), Seq( Row(1.0, 1, true, "a"), Row(12.3, 2, false, "b"), Row(12.3, 3, false, "f"), Row(4.0, 33, false, "d"), Row(12.3, 33, false, "f"), - Row(12.3, 33, false, "f") - ), - sort = false - ) + Row(12.3, 33, false, "f")), + sort = false) checkAnswer( nullData3.na.fill( - Map("flo" -> 12.3, "int" -> 44.asInstanceOf[Byte], "boo" -> false, "str" -> "f") - ), + Map("flo" -> 12.3, "int" -> 44.asInstanceOf[Byte], "boo" -> false, "str" -> "f")), Seq( Row(1.0, 1, true, "a"), Row(12.3, 2, false, "b"), Row(12.3, 3, false, "f"), Row(4.0, 44, false, "d"), Row(12.3, 44, false, "f"), - Row(12.3, 44, false, "f") - ), - sort = false - ) + Row(12.3, 44, false, "f")), + sort = false) // wrong type checkAnswer( nullData3.na.fill(Map("flo" -> 12.3, "int" -> "11", "boo" -> false, "str" -> 1)), @@ -366,10 +346,8 @@ trait DataFrameSuite extends 
TestData with BeforeAndAfterEach { Row(12.3, 3, false, null), Row(4.0, null, false, "d"), Row(12.3, null, false, null), - Row(12.3, null, false, null) - ), - sort = false - ) + Row(12.3, null, false, null)), + sort = false) // wrong column name assertThrows[SnowparkClientException](nullData3.na.fill(Map("wrong" -> 11))) @@ -384,10 +362,8 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { Row(null, 3, false, null), Row(4.0, null, null, "d"), Row(null, null, null, null), - Row(Double.NaN, null, null, null) - ), - sort = false - ) + Row(Double.NaN, null, null, null)), + sort = false) // replace null checkAnswer( nullData3.na.replace("boo", Map(None -> true)), @@ -397,10 +373,8 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { Row(null, 3, false, null), Row(4.0, null, true, "d"), Row(null, null, true, null), - Row(Double.NaN, null, true, null) - ), - sort = false - ) + Row(Double.NaN, null, true, null)), + sort = false) // replace NaN checkAnswer( @@ -411,10 +385,8 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { Row(null, 3, false, null), Row(4.0, null, null, "d"), Row(null, null, null, null), - Row(11, null, null, null) - ), - sort = false - ) + Row(11, null, null, null)), + sort = false) // incompatible type assertThrows[SnowflakeSQLException](nullData3.na.replace("flo", Map(None -> "aa")).collect()) @@ -428,10 +400,8 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { Row(null, 3, false, null), Row(4.0, null, null, "d"), Row(null, null, null, null), - Row(null, null, null, null) - ), - sort = false - ) + Row(null, null, null, null)), + sort = false) assert( getSchemaString(nullData3.na.replace("flo", Map(Double.NaN -> null)).schema) == @@ -440,8 +410,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { | |--INT: Long (nullable = true) | |--BOO: Boolean (nullable = true) | |--STR: String (nullable = true) - |""".stripMargin - ) + |""".stripMargin) } @@ -478,8 +447,7 @@ trait 
DataFrameSuite extends TestData with BeforeAndAfterEach { assert(approxNumbers.stat.approxQuantile("a", Array(0.5))(0).get == 4.5) assert( approxNumbers.stat.approxQuantile("a", Array(0, 0.1, 0.4, 0.6, 1)).deep == - Array(Some(0.0), Some(0.9), Some(3.6), Some(5.3999999999999995), Some(9.0)).deep - ) + Array(Some(0.0), Some(0.9), Some(3.6), Some(5.3999999999999995), Some(9.0)).deep) // Probability out of range error and apply on string column error. assertThrows[SnowflakeSQLException](approxNumbers.stat.approxQuantile("a", Array(-1d))) @@ -516,8 +484,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { ||2 |800 |APR | ||2 |4500 |APR | |-------------------------------- - |""".stripMargin - ) + |""".stripMargin) assert( getShowString(monthlySales.stat.sampleBy(col("month"), Map("JAN" -> 1.0)), 10) == @@ -529,8 +496,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { ||2 |4500 |JAN | ||2 |35000 |JAN | |-------------------------------- - |""".stripMargin - ) + |""".stripMargin) assert( getShowString(monthlySales.stat.sampleBy(col("month"), Map()), 10) == @@ -538,8 +504,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { ||"EMPID" |"AMOUNT" |"MONTH" | |-------------------------------- |-------------------------------- - |""".stripMargin - ) + |""".stripMargin) } // On GitHub Action this test time out. But locally it passed. 
@@ -573,8 +538,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { |----------------------------------- ||1 |1000 | |----------------------------------- - |""".stripMargin - ) + |""".stripMargin) val df4 = Seq .fill(1001) { @@ -588,8 +552,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { |----------------------------------- ||1 |1001 | |----------------------------------- - |""".stripMargin - ) + |""".stripMargin) } test("select *") { @@ -636,8 +599,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { val halfRowCount = Math.max(rowCount * 0.5, 1) assert( Math.abs(df1.sample(0.50).count() - halfRowCount) < - halfRowCount * samplingDeviation - ) + halfRowCount * samplingDeviation) // Sample all rows assert(df1.sample(1.0).count() == rowCount) } @@ -665,8 +627,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { assert(result.sample(sampleRowCount).count() == sampleRowCount) assert( Math.abs(result.sample(0.10).count() - sampleRowCount) < - sampleRowCount * samplingDeviation - ) + sampleRowCount * samplingDeviation) } test("sample() on union") { @@ -680,16 +641,14 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { assert(result.sample(sampleRowCount).count() == sampleRowCount) assert( Math.abs(result.sample(0.10).count() - sampleRowCount) < - sampleRowCount * samplingDeviation - ) + sampleRowCount * samplingDeviation) // Test union all result = df1.unionAll(df2) sampleRowCount = Math.max(result.count() / 10, 1) assert(result.sample(sampleRowCount).count() == sampleRowCount) assert( Math.abs(result.sample(0.10).count() - sampleRowCount) < - sampleRowCount * samplingDeviation - ) + sampleRowCount * samplingDeviation) } test("randomSplit()") { @@ -703,8 +662,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { weights: Array[Double], index: Int, count: Long, - TotalCount: Long - ): Unit = { + TotalCount: Long): Unit = { val expectedRowCount = TotalCount * weights(index) / 
weights.sum assert(Math.abs(expectedRowCount - count) < expectedRowCount * samplingDeviation) } @@ -831,8 +789,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { assert( sortedRows(i - 1).getLong(0) < sortedRows(i).getLong(0) || (sortedRows(i - 1).getLong(0) == sortedRows(i).getLong(0) && - sortedRows(i - 1).getLong(1) <= sortedRows(i).getLong(1)) - ) + sortedRows(i - 1).getLong(1) <= sortedRows(i).getLong(1))) } // order DESC with 2 column sortedRows = df.sort(col("a").desc, col("b").desc).collect() @@ -840,8 +797,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { assert( sortedRows(i - 1).getLong(0) > sortedRows(i).getLong(0) || (sortedRows(i - 1).getLong(0) == sortedRows(i).getLong(0) && - sortedRows(i - 1).getLong(1) >= sortedRows(i).getLong(1)) - ) + sortedRows(i - 1).getLong(1) >= sortedRows(i).getLong(1))) } // Negative test: sort() needs at least one sort expression. @@ -978,8 +934,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { ("country B", "state A", 100), ("country B", "state A", 100), ("country B", "state B", 10), - ("country B", "state B", 10) - ) + ("country B", "state B", 10)) .toDF("country", "state", "value") // At least one column needs to be provided ( negative test ) @@ -1005,36 +960,31 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { Row("country A", "state B", 10), Row("country B", null, 220), Row("country B", "state A", 200), - Row("country B", "state B", 20) - ) + Row("country B", "state B", 20)) checkAnswer( df.rollup("country", "state") .agg(sum(col("value"))) .sort(col("country"), col("state")), expectedResult, - false - ) + false) checkAnswer( df.rollup(Seq("country", "state")) .agg(sum(col("value"))) .sort(col("country"), col("state")), expectedResult, - false - ) + false) checkAnswer( df.rollup(col("country"), col("state")) .agg(sum(col("value"))) .sort(col("country"), col("state")), expectedResult, - false - ) + false) checkAnswer( 
df.rollup(Seq(col("country"), col("state"))) .agg(sum(col("value"))) .sort(col("country"), col("state")), expectedResult, - false - ) + false) } test("groupBy()") { @@ -1046,8 +996,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { ("country B", "state A", 100), ("country B", "state A", 100), ("country B", "state B", 10), - ("country B", "state B", 10) - ) + ("country B", "state B", 10)) .toDF("country", "state", "value") // groupBy() without column @@ -1067,28 +1016,23 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { Row("country A", "state A", 100), Row("country A", "state B", 10), Row("country B", "state A", 200), - Row("country B", "state B", 20) - ) + Row("country B", "state B", 20)) checkAnswer( df.groupBy("country", "state") .agg(sum(col("value"))), - expectedResult - ) + expectedResult) checkAnswer( df.groupBy(Seq("country", "state")) .agg(sum(col("value"))), - expectedResult - ) + expectedResult) checkAnswer( df.groupBy(col("country"), col("state")) .agg(sum(col("value"))), - expectedResult - ) + expectedResult) checkAnswer( df.groupBy(Seq(col("country"), col("state"))) .agg(sum(col("value"))), - expectedResult - ) + expectedResult) } test("cube()") { @@ -1100,8 +1044,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { ("country B", "state A", 100), ("country B", "state A", 100), ("country B", "state B", 10), - ("country B", "state B", 10) - ) + ("country B", "state B", 10)) .toDF("country", "state", "value") // At least one column needs to be provided ( negative test ) @@ -1130,36 +1073,31 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { Row("country A", "state B", 10), Row("country B", null, 220), Row("country B", "state A", 200), - Row("country B", "state B", 20) - ) + Row("country B", "state B", 20)) checkAnswer( df.cube("country", "state") .agg(sum(col("value"))) .sort(col("country"), col("state")), expectedResult, - false - ) + false) checkAnswer( df.cube(Seq("country", "state")) 
.agg(sum(col("value"))) .sort(col("country"), col("state")), expectedResult, - false - ) + false) checkAnswer( df.cube(col("country"), col("state")) .agg(sum(col("value"))) .sort(col("country"), col("state")), expectedResult, - false - ) + false) checkAnswer( df.cube(Seq(col("country"), col("state"))) .agg(sum(col("value"))) .sort(col("country"), col("state")), expectedResult, - false - ) + false) } test("flatten") { @@ -1172,8 +1110,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { checkAnswer( flatten.select(table1("value"), flatten("value")), Seq(Row("[\n 1,\n 2\n]", "1"), Row("[\n 1,\n 2\n]", "2")), - sort = false - ) + sort = false) // multiple flatten val flatten1 = @@ -1181,13 +1118,11 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { checkAnswer( flatten1.select(table1("value"), flatten("value"), flatten1("value")), Seq(Row("[\n 1,\n 2\n]", "1", "1"), Row("[\n 1,\n 2\n]", "2", "1")), - sort = false - ) + sort = false) // wrong mode assertThrows[SnowparkClientException]( - flatten.flatten(col("value"), "", outer = false, recursive = false, "wrong") - ) + flatten.flatten(col("value"), "", outer = false, recursive = false, "wrong")) // contains multiple query val df = session.sql("show tables").limit(1) @@ -1210,34 +1145,29 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { checkAnswer( df2.join(df3, df2("value") === df3("value")).select(df3("value")), Seq(Row("1"), Row("2")), - sort = false - ) + sort = false) // union checkAnswer( df2.union(df3).select(col("value")), Seq(Row("1"), Row("2"), Row("1"), Row("2")), - sort = false - ) + sort = false) } test("flatten in session") { checkAnswer( session.flatten(parse_json(lit("""["a","'"]"""))).select(col("value")), Seq(Row("\"a\""), Row("\"'\"")), - sort = false - ) + sort = false) checkAnswer( session .flatten(parse_json(lit("""{"a":[1,2]}""")), "a", outer = true, recursive = true, "ARRAY") .select("value"), - Seq(Row("1"), Row("2")) - ) + Seq(Row("1"), 
Row("2"))) assertThrows[SnowparkClientException]( - session.flatten(parse_json(lit("[1]")), "", outer = false, recursive = false, "wrong") - ) + session.flatten(parse_json(lit("[1]")), "", outer = false, recursive = false, "wrong")) val df1 = session.flatten(parse_json(lit("[1,2]"))) val df2 = @@ -1246,15 +1176,13 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { "a", outer = false, recursive = false, - "BOTH" - ) + "BOTH") // union checkAnswer( df1.union(df2).select("path"), Seq(Row("[0]"), Row("[1]"), Row("a[0]"), Row("a[1]")), - sort = false - ) + sort = false) // join checkAnswer( @@ -1262,8 +1190,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { .join(df2, df1("value") === df2("value")) .select(df1("path").as("path1"), df2("path").as("path2")), Seq(Row("[0]", "a[0]"), Row("[1]", "a[1]")), - sort = false - ) + sort = false) } @@ -1281,9 +1208,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { StructField("boolean", BooleanType), StructField("binary", BinaryType), StructField("timestamp", TimestampType), - StructField("date", DateType) - ) - ) + StructField("date", DateType))) val timestamp: Long = 1606179541282L val data = Seq( @@ -1299,10 +1224,8 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { true, Array(1.toByte, 2.toByte), new Timestamp(timestamp - 100), - new Date(timestamp - 100) - ), - Row(null, null, null, null, null, null, null, null, null, null, null, null) - ) + new Date(timestamp - 100)), + Row(null, null, null, null, null, null, null, null, null, null, null, null)) val result = session.createDataFrame(data, schema) // byte, short, int, long are converted to long @@ -1323,8 +1246,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { | |--BINARY: Binary (nullable = true) | |--TIMESTAMP: Timestamp (nullable = true) | |--DATE: Date (nullable = true) - |""".stripMargin - ) + |""".stripMargin) checkAnswer(result, data, sort = false) } @@ -1339,8 +1261,7 @@ trait 
DataFrameSuite extends TestData with BeforeAndAfterEach { getSchemaString(df.schema) === """root | |--TIME: Time (nullable = true) - |""".stripMargin - ) + |""".stripMargin) } // In the result, Array, Map and Geography are String data @@ -1351,19 +1272,15 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { StructField("map", MapType(null, null)), StructField("variant", VariantType), StructField("geography", GeographyType), - StructField("geometry", GeometryType) - ) - ) + StructField("geometry", GeometryType))) val data = Seq( Row( Array("'", 2), Map("'" -> 1), new Variant(1), Geography.fromGeoJSON("POINT(30 10)"), - Geometry.fromGeoJSON("POINT(20 40)") - ), - Row(null, null, null, null, null) - ) + Geometry.fromGeoJSON("POINT(20 40)")), + Row(null, null, null, null, null)) val df = session.createDataFrame(data, schema) assert( @@ -1374,8 +1291,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { | |--VARIANT: Variant (nullable = true) | |--GEOGRAPHY: Geography (nullable = true) | |--GEOMETRY: Geometry (nullable = true) - |""".stripMargin - ) + |""".stripMargin) df.show() val expected = Seq( @@ -1396,17 +1312,14 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { | 4.000000000000000e+01 | ], | "type": "Point" - |}""".stripMargin) - ), - Row(null, null, null, null, null) - ) + |}""".stripMargin)), + Row(null, null, null, null, null)) checkAnswer(df, expected, sort = false) } test("variant in array and map") { val schema = StructType( - Seq(StructField("array", ArrayType(null)), StructField("map", MapType(null, null))) - ) + Seq(StructField("array", ArrayType(null)), StructField("map", MapType(null, null)))) val data = Seq(Row(Array(new Variant(1), new Variant("\"'")), Map("a" -> new Variant("\"'")))) val df = session.createDataFrame(data, schema) checkAnswer(df, Seq(Row("[\n 1,\n \"\\\"'\"\n]", "{\n \"a\": \"\\\"'\"\n}"))) @@ -1418,38 +1331,24 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { Row( Array( 
Geography.fromGeoJSON("point(30 10)"), - Geography.fromGeoJSON("{\"type\":\"Point\",\"coordinates\":[30,10]}") - ) - ) - ) + Geography.fromGeoJSON("{\"type\":\"Point\",\"coordinates\":[30,10]}")))) checkAnswer( session.createDataFrame(data, schema), Seq( - Row( - "[\n \"point(30 10)\",\n {\n \"coordinates\": [\n" + - " 30,\n 10\n ],\n \"type\": \"Point\"\n }\n]" - ) - ) - ) + Row("[\n \"point(30 10)\",\n {\n \"coordinates\": [\n" + + " 30,\n 10\n ],\n \"type\": \"Point\"\n }\n]"))) val schema1 = StructType(Seq(StructField("map", MapType(null, null)))) val data1 = Seq( Row( Map( "a" -> Geography.fromGeoJSON("point(30 10)"), - "b" -> Geography.fromGeoJSON("{\"type\":\"Point\",\"coordinates\":[30,10]}") - ) - ) - ) + "b" -> Geography.fromGeoJSON("{\"type\":\"Point\",\"coordinates\":[30,10]}")))) checkAnswer( session.createDataFrame(data1, schema1), Seq( - Row( - "{\n \"a\": \"point(30 10)\",\n \"b\": {\n \"coordinates\": [\n" + - " 30,\n 10\n ],\n \"type\": \"Point\"\n }\n}" - ) - ) - ) + Row("{\n \"a\": \"point(30 10)\",\n \"b\": {\n \"coordinates\": [\n" + + " 30,\n 10\n ],\n \"type\": \"Point\"\n }\n}"))) } test("escaped character") { @@ -1511,10 +1410,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { Table1( new Variant(1), Geography.fromGeoJSON("point(10 10)"), - Geometry.fromGeoJSON("point(20 40)") - ) - ) - ) + Geometry.fromGeoJSON("point(20 40)")))) df3.schema.printTreeString() checkAnswer( df3, @@ -1534,10 +1430,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { | 4.000000000000000e+01 | ], | "type": "Point" - |}""".stripMargin) - ) - ) - ) + |}""".stripMargin)))) } case class Table1(variant: Variant, geography: Geography, geometry: Geometry) @@ -1551,8 +1444,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { | |--A: Long (nullable = false) | |--B: Long (nullable = false) | |--C: Boolean (nullable = true) - |""".stripMargin - ) + |""".stripMargin) checkAnswer(df, Seq(Row(1, 1, null), Row(2, 3, true)), 
sort = false) } @@ -1564,27 +1456,21 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { assert( getSchemaString( session - .createDataFrame( - Seq( - (Some(Array(1.toByte, 2.toByte)), Array(3.toByte, 4.toByte)), - (None, Array.empty[Byte]) - ) - ) - .schema - ) == + .createDataFrame(Seq( + (Some(Array(1.toByte, 2.toByte)), Array(3.toByte, 4.toByte)), + (None, Array.empty[Byte]))) + .schema) == """root | |--_1: Binary (nullable = true) | |--_2: Binary (nullable = false) - |""".stripMargin - ) + |""".stripMargin) } test("primitive array") { checkAnswer( session .createDataFrame(Seq(Row(Array(1))), StructType(Seq(StructField("arr", ArrayType(null))))), - Seq(Row("[\n 1\n]")) - ) + Seq(Row("[\n 1\n]"))) } test("time, date and timestamp test") { @@ -1594,11 +1480,9 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { .sql("select '1970-1-1 00:00:00' :: Timestamp") .collect()(0) .getTimestamp(0) - .toString == "1970-01-01 00:00:00.0" - ) + .toString == "1970-01-01 00:00:00.0") assert( - session.sql("select '1970-1-1' :: Date").collect()(0).getDate(0).toString == "1970-01-01" - ) + session.sql("select '1970-1-1' :: Date").collect()(0).getDate(0).toString == "1970-01-01") } test("quoted column names") { @@ -1612,8 +1496,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { createTable( tableName, s"$normalName int, $lowerCaseName int, $quoteStart int," + - s"$quoteEnd int, $quoteMiddle int, $quoteAllCases int" - ) + s"$quoteEnd int, $quoteMiddle int, $quoteAllCases int") runQuery(s"insert into $tableName values(1, 2, 3, 4, 5, 6)", session) // Test select() @@ -1628,8 +1511,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { schema1.fields(2).name.equals(quoteStart) && schema1.fields(3).name.equals(quoteEnd) && schema1.fields(4).name.equals(quoteMiddle) && - schema1.fields(5).name.equals(quoteAllCases) - ) + schema1.fields(5).name.equals(quoteAllCases)) checkAnswer(df1, Seq(Row(1, 2, 3, 4, 5, 6))) // Test 
select() + cacheResult() + select() @@ -1646,8 +1528,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { schema2.fields(2).name.equals(quoteStart) && schema2.fields(3).name.equals(quoteEnd) && schema2.fields(4).name.equals(quoteMiddle) && - schema2.fields(5).name.equals(quoteAllCases) - ) + schema2.fields(5).name.equals(quoteAllCases)) checkAnswer(df2, Seq(Row(1, 2, 3, 4, 5, 6))) // Test drop() @@ -1680,8 +1561,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { createTable( tableName, s"$normalName int, $lowerCaseName int, $quoteStart int," + - s"$quoteEnd int, $quoteMiddle int, $quoteAllCases int" - ) + s"$quoteEnd int, $quoteMiddle int, $quoteAllCases int") runQuery(s"insert into $tableName values(1, 2, 3, 4, 5, 6)", session) // Test simplified input format @@ -1697,8 +1577,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { schema1.fields.length == 3 && schema1.fields(0).name.equals(quoteStart) && schema1.fields(1).name.equals(quoteEnd) && - schema1.fields(2).name.equals(quoteMiddle) - ) + schema1.fields(2).name.equals(quoteMiddle)) checkAnswer(df1, Seq(Row(3, 4, 5))) } @@ -1816,8 +1695,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { || | ||NULL | |----------- - |""".stripMargin - ) + |""".stripMargin) val df2 = Seq(("line1\nline1.1\n", 1), ("line2", 2), ("\n", 3), ("line4", 4), ("\n\n", 5), (null, 6)) @@ -1839,8 +1717,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { || | | ||NULL |6 | |----------------- - |""".stripMargin - ) + |""".stripMargin) } test("negative test to input invalid table name for saveAsTable()") { @@ -1909,8 +1786,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { ("country B", "state A", 100), ("country B", "state A", 100), ("country B", "state B", 10), - ("country B", "state B", 10) - ) + ("country B", "state B", 10)) .toDF("country", "state", "value") val expectedResult = Seq( @@ -1920,16 +1796,14 @@ trait DataFrameSuite extends 
TestData with BeforeAndAfterEach { Row("country A", "state B", 10), Row("country B", null, 220), Row("country B", "state A", 200), - Row("country B", "state B", 20) - ) + Row("country B", "state B", 20)) checkAnswer( df.rollup(Array($"country", $"state")) .agg(sum(col("value"))) .sort(col("country"), col("state")), expectedResult, - sort = false - ) + sort = false) } test("rollup(String) with array args") { val df = Seq( @@ -1940,8 +1814,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { ("country B", "state A", 100), ("country B", "state A", 100), ("country B", "state B", 10), - ("country B", "state B", 10) - ) + ("country B", "state B", 10)) .toDF("country", "state", "value") val expectedResult = Seq( @@ -1951,16 +1824,14 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { Row("country A", "state B", 10), Row("country B", null, 220), Row("country B", "state A", 200), - Row("country B", "state B", 20) - ) + Row("country B", "state B", 20)) checkAnswer( df.rollup(Array("country", "state")) .agg(sum(col("value"))) .sort(col("country"), col("state")), expectedResult, - sort = false - ) + sort = false) } test("groupBy with array args") { @@ -1972,22 +1843,19 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { ("country B", "state A", 100), ("country B", "state A", 100), ("country B", "state B", 10), - ("country B", "state B", 10) - ) + ("country B", "state B", 10)) .toDF("country", "state", "value") val expectedResult = Seq( Row("country A", "state A", 100), Row("country A", "state B", 10), Row("country B", "state A", 200), - Row("country B", "state B", 20) - ) + Row("country B", "state B", 20)) checkAnswer( df.groupBy(Array($"country", $"state")) .agg(sum(col("value"))), - expectedResult - ) + expectedResult) } test("groupBy(String) with array args") { @@ -1999,22 +1867,19 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { ("country B", "state A", 100), ("country B", "state A", 100), ("country B", "state 
B", 10), - ("country B", "state B", 10) - ) + ("country B", "state B", 10)) .toDF("country", "state", "value") val expectedResult = Seq( Row("country A", "state A", 100), Row("country A", "state B", 10), Row("country B", "state A", 200), - Row("country B", "state B", 20) - ) + Row("country B", "state B", 20)) checkAnswer( df.groupBy(Array("country", "state")) .agg(sum(col("value"))), - expectedResult - ) + expectedResult) } test("test rename: basic") { @@ -2029,8 +1894,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { s"""root | |--A: Long (nullable = false) | |--B1: Long (nullable = false) - |""".stripMargin - ) + |""".stripMargin) checkAnswer(df2, Seq(Row(1, 2))) // rename column 'a as 'a1 @@ -2041,8 +1905,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { s"""root | |--A1: Long (nullable = false) | |--B1: Long (nullable = false) - |""".stripMargin - ) + |""".stripMargin) checkAnswer(df2, Seq(Row(1, 2))) } @@ -2065,8 +1928,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { | |--LEFT_B: Long (nullable = false) | |--RIGHT_A: Long (nullable = false) | |--RIGHT_C: Long (nullable = false) - |""".stripMargin - ) + |""".stripMargin) checkAnswer(df2, Seq(Row(1, 2, 3, 4))) // Get columns for right DF's columns @@ -2077,8 +1939,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { s"""root | |--RIGHT_A: Long (nullable = false) | |--RIGHT_C: Long (nullable = false) - |""".stripMargin - ) + |""".stripMargin) checkAnswer(df3, Seq(Row(3, 4))) } @@ -2096,8 +1957,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { | |--B1: Long (nullable = false) | |--A: Long (nullable = false) | |--B: Long (nullable = false) - |""".stripMargin - ) + |""".stripMargin) checkAnswer(df5, Seq(Row(1, 2, 1, 2))) } @@ -2112,9 +1972,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { ex1.errorCode.equals("0120") && ex1.message.contains( "Unable to rename the column Column[Literal(c,Some(String))] as 
\"C\"" + - " because this DataFrame doesn't have a column named Column[Literal(c,Some(String))]." - ) - ) + " because this DataFrame doesn't have a column named Column[Literal(c,Some(String))].")) // rename un-exist column val ex2 = intercept[SnowparkClientException] { @@ -2122,11 +1980,8 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { } assert( ex2.errorCode.equals("0120") && - ex2.message.contains( - "Unable to rename the column \"NOT_EXIST_COLUMN\" as \"C\"" + - " because this DataFrame doesn't have a column named \"NOT_EXIST_COLUMN\"." - ) - ) + ex2.message.contains("Unable to rename the column \"NOT_EXIST_COLUMN\" as \"C\"" + + " because this DataFrame doesn't have a column named \"NOT_EXIST_COLUMN\".")) // rename a column has 3 duplicate names in the DataFrame val df2 = session.sql("select 1 as A, 2 as A, 3 as A") @@ -2135,17 +1990,13 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { } assert( ex3.errorCode.equals("0121") && - ex3.message.contains( - "Unable to rename the column \"A\" as \"C\" because" + - " this DataFrame has 3 columns named \"A\"" - ) - ) + ex3.message.contains("Unable to rename the column \"A\" as \"C\" because" + + " this DataFrame has 3 columns named \"A\"")) } test("with columns keep order", JavaStoredProcExclude) { val data = new Variant( - Map("STARTTIME" -> 0, "ENDTIME" -> 10000, "START_STATION_ID" -> 2, "END_STATION_ID" -> 3) - ) + Map("STARTTIME" -> 0, "ENDTIME" -> 10000, "START_STATION_ID" -> 2, "END_STATION_ID" -> 3)) val df = Seq((1, data)).toDF("TRIPID", "V") val result = df.withColumns( @@ -2155,9 +2006,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { to_timestamp(get(col("V"), lit("ENDTIME"))), datediff("minute", col("STARTTIME"), col("ENDTIME")), as_integer(get(col("V"), lit("START_STATION_ID"))), - as_integer(get(col("V"), lit("END_STATION_ID"))) - ) - ) + as_integer(get(col("V"), lit("END_STATION_ID"))))) checkAnswer( result, @@ -2170,10 +2019,7 @@ trait 
DataFrameSuite extends TestData with BeforeAndAfterEach { Timestamp.valueOf("1969-12-31 18:46:40.0"), 166, 2, - 3 - ) - ) - ) + 3))) } test("withColumns input doesn't match each other") { @@ -2181,9 +2027,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { val msg = intercept[SnowparkClientException](df.withColumns(Seq("e", "f"), Seq(lit(1)))) assert( msg.message.contains( - "The number of column names (2) does not match the number of values (1)." - ) - ) + "The number of column names (2) does not match the number of values (1).")) } test("withColumns replace exiting") { @@ -2192,20 +2036,16 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { checkAnswer(replaced, Seq(Row(1, 3, 5, 6))) val msg = intercept[SnowparkClientException]( - df.withColumns(Seq("d", "b", "d"), Seq(lit(4), lit(5), lit(6))) - ) + df.withColumns(Seq("d", "b", "d"), Seq(lit(4), lit(5), lit(6)))) assert( - msg.message.contains("The same column name is used multiple times in the colNames parameter.") - ) + msg.message.contains( + "The same column name is used multiple times in the colNames parameter.")) val msg1 = intercept[SnowparkClientException]( - df.withColumns(Seq("d", "b", "D"), Seq(lit(4), lit(5), lit(6))) - ) + df.withColumns(Seq("d", "b", "D"), Seq(lit(4), lit(5), lit(6)))) assert( msg1.message.contains( - "The same column name is used multiple times in the colNames parameter." 
- ) - ) + "The same column name is used multiple times in the colNames parameter.")) } test("dropDuplicates") { @@ -2213,8 +2053,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { .toDF("a", "b", "c", "d") checkAnswer( df.dropDuplicates(), - Seq(Row(1, 1, 1, 1), Row(1, 1, 1, 2), Row(1, 1, 2, 3), Row(1, 2, 3, 4)) - ) + Seq(Row(1, 1, 1, 1), Row(1, 1, 1, 2), Row(1, 1, 2, 3), Row(1, 2, 3, 4))) val result1 = df.dropDuplicates("a") assert(result1.count() == 1) @@ -2225,7 +2064,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { case (1, 1, 1, 2) => case (1, 1, 2, 3) => case (1, 2, 3, 4) => - case _ => throw new Exception("wrong result") + case _ => throw new Exception("wrong result") } val result2 = df.dropDuplicates("a", "b") @@ -2237,7 +2076,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { case (1, 1, 1, 1) => case (1, 1, 1, 2) => case (1, 1, 2, 3) => - case _ => throw new Exception("wrong result") + case _ => throw new Exception("wrong result") } val result3 = df.dropDuplicates("a", "b", "c") @@ -2249,18 +2088,16 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { (row3.getInt(0), row3.getInt(1), row3.getInt(2), row3.getInt(3)) match { case (1, 1, 1, 1) => case (1, 1, 1, 2) => - case _ => throw new Exception("wrong result") + case _ => throw new Exception("wrong result") } checkAnswer( df.dropDuplicates("a", "b", "c", "d"), - Seq(Row(1, 1, 1, 1), Row(1, 1, 1, 2), Row(1, 1, 2, 3), Row(1, 2, 3, 4)) - ) + Seq(Row(1, 1, 1, 1), Row(1, 1, 1, 2), Row(1, 1, 2, 3), Row(1, 2, 3, 4))) checkAnswer( df.dropDuplicates("a", "b", "c", "d", "d"), - Seq(Row(1, 1, 1, 1), Row(1, 1, 1, 2), Row(1, 1, 2, 3), Row(1, 2, 3, 4)) - ) + Seq(Row(1, 1, 1, 1), Row(1, 1, 1, 2), Row(1, 1, 2, 3), Row(1, 2, 3, 4))) // column doesn't exist assertThrows[SnowparkClientException](df.dropDuplicates("e").collect()) @@ -2283,7 +2120,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { case (1, 1, 1, 2) => case (1, 1, 2, 3) => 
case (1, 2, 3, 4) => - case _ => throw new Exception("wrong result") + case _ => throw new Exception("wrong result") } } @@ -2297,9 +2134,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { (8, 5, 11, "green", 99), (8, 4, 14, "blue", 99), (8, 3, 21, "red", 99), - (9, 9, 12, "orange", 99) - ) - ) + (9, 9, 12, "orange", 99))) .toDF("v1", "v2", "length", "color", "unused") // Wrapped JDBC exception @@ -2324,11 +2159,8 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { } assert(ex.errorCode.equals("0108")) assert( - ex.message.contains( - "The DataFrame does not contain the column named" + - " 'NOT_EXIST_COL' and the valid names are \"A\", \"B\"" - ) - ) + ex.message.contains("The DataFrame does not contain the column named" + + " 'NOT_EXIST_COL' and the valid names are \"A\", \"B\"")) } } @@ -2336,8 +2168,7 @@ class EagerDataFrameSuite extends DataFrameSuite with EagerSession { test("eager analysis") { // reports errors assertThrows[SnowflakeSQLException]( - session.sql("select something").select("111").filter(col("+++") > "aaa") - ) + session.sql("select something").select("111").filter(col("+++") > "aaa")) } } diff --git a/src/test/scala/com/snowflake/snowpark_test/DataFrameWriterSuite.scala b/src/test/scala/com/snowflake/snowpark_test/DataFrameWriterSuite.scala index c9bcce01..f5aab2b7 100644 --- a/src/test/scala/com/snowflake/snowpark_test/DataFrameWriterSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/DataFrameWriterSuite.scala @@ -17,8 +17,7 @@ class DataFrameWriterSuite extends TestData { val tableName = randomTableName() private val userSchema: StructType = StructType( - Seq(StructField("a", IntegerType), StructField("b", StringType), StructField("c", DoubleType)) - ) + Seq(StructField("a", IntegerType), StructField("b", StringType), StructField("c", DoubleType))) override def beforeAll(): Unit = { super.beforeAll() @@ -191,8 +190,7 @@ class DataFrameWriterSuite extends TestData { options: Map[String, Any], 
expectedWriteResult: Array[Row], outputFileExtension: String, - saveMode: Option[SaveMode] = None - ) = { + saveMode: Option[SaveMode] = None) = { // Execute COPY INTO location and check result val writer = df.write.options(options) saveMode.foreach(writer.mode) @@ -210,8 +208,7 @@ class DataFrameWriterSuite extends TestData { options: Map[String, Any], expectedWriteResult: Array[Row], outputFileExtension: String, - saveMode: Option[SaveMode] = None - ) = { + saveMode: Option[SaveMode] = None) = { // Execute COPY INTO location and check result val writer = df.write.options(options) saveMode.foreach(writer.mode) @@ -229,8 +226,7 @@ class DataFrameWriterSuite extends TestData { options: Map[String, Any], expectedNumberOfRow: Int, outputFileExtension: String, - saveMode: Option[SaveMode] = None - ) = { + saveMode: Option[SaveMode] = None) = { // Execute COPY INTO location and check result val writer = df.write.options(options) saveMode.foreach(writer.mode) @@ -249,9 +245,7 @@ class DataFrameWriterSuite extends TestData { Seq( StructField("c1", IntegerType), StructField("c2", DoubleType), - StructField("c3", StringType) - ) - ) + StructField("c3", StringType))) val df = session.table(tableName) val path = s"@$targetStageName/p_${Random.nextInt().abs}" @@ -259,8 +253,7 @@ class DataFrameWriterSuite extends TestData { runCSvTest(df, path, Map.empty, Array(Row(3, 32, 46)), ".csv.gz") checkAnswer( session.read.schema(schema).csv(path), - Seq(Row(1, 1.1, "one"), Row(2, 2.2, "two"), Row(null, null, null)) - ) + Seq(Row(1, 1.1, "one"), Row(2, 2.2, "two"), Row(null, null, null))) // by default, the mode is ErrorIfExist val ex = intercept[SnowflakeSQLException] { @@ -272,8 +265,7 @@ class DataFrameWriterSuite extends TestData { runCSvTest(df, path, Map.empty, Array(Row(3, 32, 46)), ".csv.gz", Some(SaveMode.Overwrite)) checkAnswer( session.read.schema(schema).csv(path), - Seq(Row(1, 1.1, "one"), Row(2, 2.2, "two"), Row(null, null, null)) - ) + Seq(Row(1, 1.1, "one"), Row(2, 2.2, 
"two"), Row(null, null, null))) // test some file format options and values session.sql(s"remove $path").collect() @@ -281,13 +273,11 @@ class DataFrameWriterSuite extends TestData { "FIELD_DELIMITER" -> "'aa'", "RECORD_DELIMITER" -> "bbbb", "COMPRESSION" -> "NONE", - "FILE_EXTENSION" -> "mycsv" - ) + "FILE_EXTENSION" -> "mycsv") runCSvTest(df, path, options1, Array(Row(3, 47, 47)), ".mycsv") checkAnswer( session.read.schema(schema).options(options1).csv(path), - Seq(Row(1, 1.1, "one"), Row(2, 2.2, "two"), Row(null, null, null)) - ) + Seq(Row(1, 1.1, "one"), Row(2, 2.2, "two"), Row(null, null, null))) // Test file format name only val fileFormatName = randomTableName() @@ -295,8 +285,7 @@ class DataFrameWriterSuite extends TestData { .sql( s"CREATE OR REPLACE TEMPORARY FILE FORMAT $fileFormatName " + s"TYPE = CSV FIELD_DELIMITER = 'aa' RECORD_DELIMITER = 'bbbb' " + - s"COMPRESSION = 'NONE' FILE_EXTENSION = 'mycsv'" - ) + s"COMPRESSION = 'NONE' FILE_EXTENSION = 'mycsv'") .collect() runCSvTest( df, @@ -304,20 +293,17 @@ class DataFrameWriterSuite extends TestData { Map("FORMAT_NAME" -> fileFormatName), Array(Row(3, 47, 47)), ".mycsv", - Some(SaveMode.Overwrite) - ) + Some(SaveMode.Overwrite)) checkAnswer( session.read.schema(schema).options(options1).csv(path), - Seq(Row(1, 1.1, "one"), Row(2, 2.2, "two"), Row(null, null, null)) - ) + Seq(Row(1, 1.1, "one"), Row(2, 2.2, "two"), Row(null, null, null))) // Test file format name and some extra format options val fileFormatName2 = randomTableName() session .sql( s"CREATE OR REPLACE TEMPORARY FILE FORMAT $fileFormatName2 " + - s"TYPE = CSV FIELD_DELIMITER = 'aa' RECORD_DELIMITER = 'bbbb'" - ) + s"TYPE = CSV FIELD_DELIMITER = 'aa' RECORD_DELIMITER = 'bbbb'") .collect() val formatNameAndOptions = Map("FORMAT_NAME" -> fileFormatName2, "COMPRESSION" -> "NONE", "FILE_EXTENSION" -> "mycsv") @@ -327,12 +313,10 @@ class DataFrameWriterSuite extends TestData { formatNameAndOptions, Array(Row(3, 47, 47)), ".mycsv", - 
Some(SaveMode.Overwrite) - ) + Some(SaveMode.Overwrite)) checkAnswer( session.read.schema(schema).options(options1).csv(path), - Seq(Row(1, 1.1, "one"), Row(2, 2.2, "two"), Row(null, null, null)) - ) + Seq(Row(1, 1.1, "one"), Row(2, 2.2, "two"), Row(null, null, null))) } // copyOptions ::= @@ -348,9 +332,7 @@ class DataFrameWriterSuite extends TestData { Seq( StructField("c1", IntegerType), StructField("c2", DoubleType), - StructField("c3", StringType) - ) - ) + StructField("c3", StringType))) val df = session.table(tableName) val path = s"@$targetStageName/p_${Random.nextInt().abs}" @@ -360,8 +342,7 @@ class DataFrameWriterSuite extends TestData { runCSvTest(df, path2, Map("SINGLE" -> true), Array(Row(3, 32, 46)), targetFile) checkAnswer( session.read.schema(schema).csv(path2), - Seq(Row(1, 1.1, "one"), Row(2, 2.2, "two"), Row(null, null, null)) - ) + Seq(Row(1, 1.1, "one"), Row(2, 2.2, "two"), Row(null, null, null))) // other copy options session.sql(s"rm $path").collect() @@ -374,8 +355,7 @@ class DataFrameWriterSuite extends TestData { assert(resultFiles.length == 1 && resultFiles(0).getString(0).contains(queryId)) checkAnswer( session.read.schema(schema).csv(path), - Seq(Row(1, 1.1, "one"), Row(2, 2.2, "two"), Row(null, null, null)) - ) + Seq(Row(1, 1.1, "one"), Row(2, 2.2, "two"), Row(null, null, null))) } // sub clause: @@ -400,8 +380,7 @@ class DataFrameWriterSuite extends TestData { | ,('2020-01-28', null) | ,('2020-01-29', '02:15') |""".stripMargin, - session - ) + session) val schema = StructType(Seq(StructField("dt", DateType), StructField("tm", TimeType))) val df = session.table(tableName) val path = s"@$targetStageName/p_${Random.nextInt().abs}" @@ -410,8 +389,7 @@ class DataFrameWriterSuite extends TestData { Map( "header" -> true, "partition by" -> ("('date=' || to_varchar(dt, 'YYYY-MM-DD') ||" + - " '/hour=' || to_varchar(date_part(hour, ts)))") - ) + " '/hour=' || to_varchar(date_part(hour, ts)))")) val copyResult = 
df.write.options(options).csv(path).rows checkResult(copyResult, Array(Row(4, 99, 179))) @@ -423,9 +401,7 @@ class DataFrameWriterSuite extends TestData { Row(Date.valueOf("2020-01-26"), Time.valueOf("18:05:00")), Row(Date.valueOf("2020-01-27"), Time.valueOf("22:57:00")), Row(Date.valueOf("2020-01-28"), null), - Row(Date.valueOf("2020-01-29"), Time.valueOf("02:15:00")) - ) - ) + Row(Date.valueOf("2020-01-29"), Time.valueOf("02:15:00")))) } finally { TimeZone.setDefault(oldTimeZone) session.sql(s"alter session set TIMEZONE = '$sfTimezone'").collect() @@ -476,16 +452,14 @@ class DataFrameWriterSuite extends TestData { dfReadFile.write.csv(path) checkAnswer( session.read.schema(userSchema).csv(path), - Seq(Row(1, "one", 1.2), Row(2, "two", 2.2)) - ) + Seq(Row(1, "one", 1.2), Row(2, "two", 2.2))) // read with copy options runQuery(s"rm $path", session) val dfReadCopy = session.read.schema(userSchema).option("PURGE", false).csv(testFileOnStage) dfReadCopy.write.csv(path) checkAnswer( session.read.schema(userSchema).csv(path), - Seq(Row(1, "one", 1.2), Row(2, "two", 2.2)) - ) + Seq(Row(1, "one", 1.2), Row(2, "two", 2.2))) } test("negative test") { @@ -500,27 +474,21 @@ class DataFrameWriterSuite extends TestData { } assert( ex.getMessage.contains( - "DataFrameWriter doesn't support option 'OVERWRITE' when writing to a file." - ) - ) + "DataFrameWriter doesn't support option 'OVERWRITE' when writing to a file.")) val ex2 = intercept[SnowparkClientException] { df.write.option("TYPE", "CSV").csv(path) } assert( ex2.getMessage.contains( - "DataFrameWriter doesn't support option 'TYPE' when writing to a file." - ) - ) + "DataFrameWriter doesn't support option 'TYPE' when writing to a file.")) val ex3 = intercept[SnowparkClientException] { df.write.option("unknown", "abc").csv(path) } assert( ex3.getMessage.contains( - "DataFrameWriter doesn't support option 'UNKNOWN' when writing to a file." 
- ) - ) + "DataFrameWriter doesn't support option 'UNKNOWN' when writing to a file.")) // only support ErrorIfExists and Overwrite mode val ex4 = intercept[SnowparkClientException] { @@ -528,9 +496,7 @@ class DataFrameWriterSuite extends TestData { } assert( ex4.getMessage.contains( - "DataFrameWriter doesn't support mode 'Append' when writing to a file." - ) - ) + "DataFrameWriter doesn't support mode 'Append' when writing to a file.")) } // JSON can only be used to unload data from columns of type VARIANT @@ -545,8 +511,7 @@ class DataFrameWriterSuite extends TestData { runJsonTest(df, path, Map.empty, Array(Row(2, 20, 40)), ".json.gz") checkAnswer( session.read.json(path), - Seq(Row("[\n 1,\n \"one\"\n]"), Row("[\n 2,\n \"two\"\n]")) - ) + Seq(Row("[\n 1,\n \"one\"\n]"), Row("[\n 2,\n \"two\"\n]"))) // write one column and overwrite val df2 = session.table(tableName).select(to_variant(col("c2"))) @@ -563,8 +528,7 @@ class DataFrameWriterSuite extends TestData { Map("FORMAT_NAME" -> formatName), Array(Row(2, 4, 24)), ".json.gz", - Some(SaveMode.Overwrite) - ) + Some(SaveMode.Overwrite)) session.read.json(path).show() checkAnswer(session.read.json(path), Seq(Row("1"), Row("2"))) @@ -577,8 +541,7 @@ class DataFrameWriterSuite extends TestData { Map("FORMAT_NAME" -> formatName, "FILE_EXTENSION" -> "myjson.json", "COMPRESSION" -> "NONE"), Array(Row(2, 4, 4)), ".myjson.json", - Some(SaveMode.Overwrite) - ) + Some(SaveMode.Overwrite)) session.read.json(path).show() checkAnswer(session.read.json(path), Seq(Row("1"), Row("2"))) } @@ -595,9 +558,7 @@ class DataFrameWriterSuite extends TestData { session.read.parquet(path), Seq( Row("{\n \"_COL_0\": 1,\n \"_COL_1\": \"one\"\n}"), - Row("{\n \"_COL_0\": 2,\n \"_COL_1\": \"two\"\n}") - ) - ) + Row("{\n \"_COL_0\": 2,\n \"_COL_1\": \"two\"\n}"))) // write with overwrite runParquetTest(df, path, Map.empty, 2, ".snappy.parquet", Some(SaveMode.Overwrite)) @@ -605,9 +566,7 @@ class DataFrameWriterSuite extends TestData { 
session.read.parquet(path), Seq( Row("{\n \"_COL_0\": 1,\n \"_COL_1\": \"one\"\n}"), - Row("{\n \"_COL_0\": 2,\n \"_COL_1\": \"two\"\n}") - ) - ) + Row("{\n \"_COL_0\": 2,\n \"_COL_1\": \"two\"\n}"))) // write with format_name val formatName = randomTableName() @@ -620,15 +579,12 @@ class DataFrameWriterSuite extends TestData { Map("FORMAT_NAME" -> formatName), 2, ".snappy.parquet", - Some(SaveMode.Overwrite) - ) + Some(SaveMode.Overwrite)) checkAnswer( session.read.parquet(path), Seq( Row("{\n \"_COL_0\": 1,\n \"_COL_1\": \"one\"\n}"), - Row("{\n \"_COL_0\": 2,\n \"_COL_1\": \"two\"\n}") - ) - ) + Row("{\n \"_COL_0\": 2,\n \"_COL_1\": \"two\"\n}"))) // write with format_name format and some extra option session.sql(s"rm $path").collect() @@ -639,14 +595,11 @@ class DataFrameWriterSuite extends TestData { Map("FORMAT_NAME" -> formatName, "COMPRESSION" -> "LZO"), 2, ".lzo.parquet", - Some(SaveMode.Overwrite) - ) + Some(SaveMode.Overwrite)) checkAnswer( session.read.parquet(path), Seq( Row("{\n \"_COL_0\": 1,\n \"_COL_1\": \"one\"\n}"), - Row("{\n \"_COL_0\": 2,\n \"_COL_1\": \"two\"\n}") - ) - ) + Row("{\n \"_COL_0\": 2,\n \"_COL_1\": \"two\"\n}"))) } } diff --git a/src/test/scala/com/snowflake/snowpark_test/DataTypeSuite.scala b/src/test/scala/com/snowflake/snowpark_test/DataTypeSuite.scala index 2b410686..1af50f2d 100644 --- a/src/test/scala/com/snowflake/snowpark_test/DataTypeSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/DataTypeSuite.scala @@ -91,8 +91,7 @@ class DataTypeSuite extends SNTestBase { assert(tpe.typeName == "Struct") assert( tpe.toString == "StructType[StructField(COL1, Integer, Nullable = true), " + - "StructField(COL2, String, Nullable = false)]" - ) + "StructField(COL2, String, Nullable = false)]") assert(tpe(1) == StructField("col2", StringType, nullable = false)) assert(tpe("col1") == StructField("col1", IntegerType)) @@ -116,30 +115,22 @@ class DataTypeSuite extends SNTestBase { StructField( "col14", 
StructType(Seq(StructField("col15", TimestampType, nullable = false))), - nullable = false - ), + nullable = false), StructField("col3", DateType, nullable = false), StructField( "col4", - StructType( - Seq( - StructField("col5", ByteType), - StructField("col6", ShortType), - StructField("col7", IntegerType, nullable = false), - StructField("col8", LongType), - StructField( - "col12", - StructType(Seq(StructField("col13", StringType))), - nullable = false - ), - StructField("col9", FloatType), - StructField("col10", DoubleType), - StructField("col11", DecimalType(10, 1)) - ) - ) - ) - ) - ) + StructType(Seq( + StructField("col5", ByteType), + StructField("col6", ShortType), + StructField("col7", IntegerType, nullable = false), + StructField("col8", LongType), + StructField( + "col12", + StructType(Seq(StructField("col13", StringType))), + nullable = false), + StructField("col9", FloatType), + StructField("col10", DoubleType), + StructField("col11", DecimalType(10, 1))))))) assert( TestUtils.treeString(schema, 0) == @@ -159,8 +150,7 @@ class DataTypeSuite extends SNTestBase { | |--COL9: Float (nullable = true) | |--COL10: Double (nullable = true) | |--COL11: Decimal(10, 1) (nullable = true) - |""".stripMargin - ) + |""".stripMargin) } test("ColumnIdentifier") { @@ -190,8 +180,7 @@ class DataTypeSuite extends SNTestBase { s"""root | |--A: Decimal(5, 2) (nullable = false) | |--B: Decimal(7, 2) (nullable = false) - |""".stripMargin - ) + |""".stripMargin) } test("read Structured Array") { @@ -231,9 +220,7 @@ class DataTypeSuite extends SNTestBase { Array("{\n \"name\": 1\n}"), Array(Array(1L, 2L), Array(3L, 4L)), Array(java.math.BigDecimal.valueOf(1.234)), - Array(Time.valueOf("10:03:56")) - ) - ) + Array(Time.valueOf("10:03:56")))) } finally { TimeZone.setDefault(oldTimeZone) } @@ -260,9 +247,7 @@ class DataTypeSuite extends SNTestBase { Map(2 -> Array(4L, 5L, 6L), 1 -> Array(1L, 2L, 3L)), Map(2 -> Map("c" -> 3), 1 -> Map("a" -> 1, "b" -> 2)), Array(Map("a" -> 1, "b" 
-> 2), Map("c" -> 3)), - "{\n \"a\": 1,\n \"b\": 2\n}" - ) - ) + "{\n \"a\": 1,\n \"b\": 2\n}")) } } @@ -366,8 +351,7 @@ class DataTypeSuite extends SNTestBase { | |--ARR10: Array[Map nullable = true] (nullable = true) | |--ARR11: Array[Array[Long nullable = true] nullable = true] (nullable = true) | |--ARR0: Array (nullable = true) - |""".stripMargin - ) + |""".stripMargin) // schema string: nullable assert( // since we retrieved the schema of df before, df.select("*") will use the @@ -386,8 +370,7 @@ class DataTypeSuite extends SNTestBase { | |--ARR10: Array[Map nullable = true] (nullable = true) | |--ARR11: Array[Array[Long nullable = true] nullable = true] (nullable = true) | |--ARR0: Array (nullable = true) - |""".stripMargin - ) + |""".stripMargin) // schema string: not nullable val query2 = @@ -402,8 +385,7 @@ class DataTypeSuite extends SNTestBase { s"""root | |--ARR1: Array[Long nullable = false] (nullable = true) | |--ARR11: Array[Array[Long nullable = false] nullable = false] (nullable = true) - |""".stripMargin - ) + |""".stripMargin) // scalastyle:on assert( @@ -412,8 +394,7 @@ class DataTypeSuite extends SNTestBase { s"""root | |--ARR1: Array[Long nullable = false] (nullable = true) | |--ARR11: Array[Array[Long nullable = false] nullable = false] (nullable = true) - |""".stripMargin - ) + |""".stripMargin) // scalastyle:on } } @@ -438,8 +419,7 @@ class DataTypeSuite extends SNTestBase { | |--MAP3: Map[Long, Array[Long nullable = true] nullable = true] (nullable = true) | |--MAP4: Map[Long, Map[String, Long nullable = true] nullable = true] (nullable = true) | |--MAP0: Map (nullable = true) - |""".stripMargin - ) + |""".stripMargin) // scalastyle:on assert( @@ -453,8 +433,7 @@ class DataTypeSuite extends SNTestBase { | |--MAP3: Map[Long, Array[Long nullable = true] nullable = true] (nullable = true) | |--MAP4: Map[Long, Map[String, Long nullable = true] nullable = true] (nullable = true) | |--MAP0: Map (nullable = true) - |""".stripMargin - ) + 
|""".stripMargin) // scalastyle:on // nullable @@ -472,8 +451,7 @@ class DataTypeSuite extends SNTestBase { | |--MAP1: Map[String, Long nullable = false] (nullable = true) | |--MAP3: Map[Long, Array[Long nullable = false] nullable = true] (nullable = true) | |--MAP4: Map[Long, Map[String, Long nullable = false] nullable = true] (nullable = true) - |""".stripMargin - ) + |""".stripMargin) // scalastyle:on assert( @@ -483,8 +461,7 @@ class DataTypeSuite extends SNTestBase { | |--MAP1: Map[String, Long nullable = false] (nullable = true) | |--MAP3: Map[Long, Array[Long nullable = false] nullable = true] (nullable = true) | |--MAP4: Map[Long, Map[String, Long nullable = false] nullable = true] (nullable = true) - |""".stripMargin - ) + |""".stripMargin) // scalastyle:on } } @@ -519,8 +496,7 @@ class DataTypeSuite extends SNTestBase { | |--A: Struct (nullable = true) | |--B: Struct (nullable = true) | |--C: Long (nullable = true) - |""".stripMargin - ) + |""".stripMargin) // scalastyle:on // schema string: nullable @@ -542,8 +518,7 @@ class DataTypeSuite extends SNTestBase { | |--A: Struct (nullable = true) | |--B: Struct (nullable = true) | |--C: Long (nullable = true) - |""".stripMargin - ) + |""".stripMargin) // scalastyle:on // schema query: not null @@ -576,8 +551,7 @@ class DataTypeSuite extends SNTestBase { | |--A: Struct (nullable = false) | |--B: Struct (nullable = false) | |--C: Long (nullable = false) - |""".stripMargin - ) + |""".stripMargin) // scalastyle:on assert( @@ -598,8 +572,7 @@ class DataTypeSuite extends SNTestBase { | |--A: Struct (nullable = false) | |--B: Struct (nullable = false) | |--C: Long (nullable = false) - |""".stripMargin - ) + |""".stripMargin) // scalastyle:on } } diff --git a/src/test/scala/com/snowflake/snowpark_test/FileOperationSuite.scala b/src/test/scala/com/snowflake/snowpark_test/FileOperationSuite.scala index 8ddeeac8..4409643c 100644 --- a/src/test/scala/com/snowflake/snowpark_test/FileOperationSuite.scala +++ 
b/src/test/scala/com/snowflake/snowpark_test/FileOperationSuite.scala @@ -34,8 +34,7 @@ class FileOperationSuite extends SNTestBase { private def createTempFile( prefix: String = "test_file_", suffix: String = ".csv", - content: String = "abc, 123,\n" - ): String = { + content: String = "abc, 123,\n"): String = { val file = File.createTempFile(prefix, suffix, sourceDirectoryFile) FileUtils.write(file, content) file.getCanonicalPath @@ -115,8 +114,7 @@ class FileOperationSuite extends SNTestBase { assert( secondResult(0).message.isEmpty || secondResult(0).message - .contains("File with same destination name and checksum already exists") - ) + .contains("File with same destination name and checksum already exists")) } test("put() with one relative path file") { @@ -181,8 +179,7 @@ class FileOperationSuite extends SNTestBase { } assert( stageNotExistException.getMessage.contains("Stage") && - stageNotExistException.getMessage.contains("does not exist or not authorized.") - ) + stageNotExistException.getMessage.contains("does not exist or not authorized.")) } test("get() one file") { @@ -202,8 +199,7 @@ class FileOperationSuite extends SNTestBase { results(0).sizeBytes == 30L && results(0).status.equals("DOWNLOADED") && results(0).encryption.equals("DECRYPTED") && - results(0).message.equals("") - ) + results(0).message.equals("")) // Check downloaded file assert(fileExists(s"$targetDirectoryPath/${getFileName(path1)}.gz")) @@ -232,8 +228,7 @@ class FileOperationSuite extends SNTestBase { results(0).sizeBytes == 30L && results(0).status.equals("DOWNLOADED") && results(0).encryption.equals("DECRYPTED") && - results(0).message.equals("") - ) + results(0).message.equals("")) // Check downloaded file assert(fileExists(s"$targetDirectoryPath/${getFileName(path1)}.gz")) @@ -267,9 +262,7 @@ class FileOperationSuite extends SNTestBase { assert( r.status.equals("DOWNLOADED") && r.encryption.equals("DECRYPTED") && - r.message.equals("") - ) - ) + r.message.equals(""))) // Check 
downloaded files assert(fileExists(s"$targetDirectoryPath/${getFileName(path1)}.gz")) @@ -305,9 +298,7 @@ class FileOperationSuite extends SNTestBase { assert( r.status.equals("DOWNLOADED") && r.encryption.equals("DECRYPTED") && - r.message.equals("") - ) - ) + r.message.equals(""))) // Check downloaded files assert(fileExists(getFileName(path1) + ".gz")) @@ -329,8 +320,7 @@ class FileOperationSuite extends SNTestBase { } assert( stageNotExistException.getMessage.contains("Stage") && - stageNotExistException.getMessage.contains("does not exist or not authorized.") - ) + stageNotExistException.getMessage.contains("does not exist or not authorized.")) // If stage name exists but prefix doesn't exist, download nothing var getResults = session.file.get(s"@$tempStage/not_exist_prefix_test/", ".") @@ -366,8 +356,7 @@ class FileOperationSuite extends SNTestBase { assert( results(0).sizeBytes == 30L && results(0).status.equals("DOWNLOADED") && - results(0).message.equals("") - ) + results(0).message.equals("")) // The error message is like: // prefix/prefix_1/file_1.csv.gz has same name as prefix/prefix_2/file_1.csv.gz // GET on GCP doesn't detect this download collision. 
@@ -377,8 +366,7 @@ class FileOperationSuite extends SNTestBase { results(1).message.contains("has same name as")) || (results(1).sizeBytes == 30L && results(1).status.equals("DOWNLOADED") && - results(1).message.equals("")) - ) + results(1).message.equals(""))) // Check downloaded files assert(fileExists(targetDirectoryPath + "/" + (getFileName(path1) + ".gz"))) @@ -424,32 +412,28 @@ class FileOperationSuite extends SNTestBase { testStreamRoundTrip( s"$tempStage/$stagePrefix/$fileName", s"$tempStage/$stagePrefix/$fileName", - false - ) + false) // Test with @ prefix stagePrefix = "prefix_" + TestUtils.randomString(5) testStreamRoundTrip( s"@$tempStage/$stagePrefix/$fileName", s"@$tempStage/$stagePrefix/$fileName", - false - ) + false) // Test compression with .gz extension stagePrefix = "prefix_" + TestUtils.randomString(5) testStreamRoundTrip( s"$tempStage/$stagePrefix/$fileName.gz", s"$tempStage/$stagePrefix/$fileName.gz", - true - ) + true) // Test compression without .gz extension stagePrefix = "prefix_" + TestUtils.randomString(5) testStreamRoundTrip( s"$tempStage/$stagePrefix/$fileName", s"$tempStage/$stagePrefix/$fileName.gz", - true - ) + true) // Test no path fileName = s"streamFile_${TestUtils.randomString(5)}.csv" @@ -462,8 +446,7 @@ class FileOperationSuite extends SNTestBase { testStreamRoundTrip( s"$database.$schema.$tempStage/$fileName", s"$database.$schema.$tempStage/$fileName.gz", - true - ) + true) fileName = s"streamFile_${TestUtils.randomString(5)}.csv" testStreamRoundTrip(s"$schema.$tempStage/$fileName", s"$schema.$tempStage/$fileName.gz", true) @@ -482,8 +465,7 @@ class FileOperationSuite extends SNTestBase { testStreamRoundTrip( s"$randomNewSchema.$tempStage/$fileName", s"$randomNewSchema.$tempStage/$fileName.gz", - true - ) + true) } finally { session.sql(s"DROP SCHEMA $randomNewSchema").collect() } @@ -496,49 +478,42 @@ class FileOperationSuite extends SNTestBase { // Test no file name assertThrows[SnowparkClientException]( - 
testStreamRoundTrip(s"$tempStage/", s"$tempStage/", false) - ) + testStreamRoundTrip(s"$tempStage/", s"$tempStage/", false)) var stagePrefix = "prefix_" + TestUtils.randomString(5) var fileName = s"streamFile_${TestUtils.randomString(5)}.csv" // Test upload no stage assertThrows[SnowflakeSQLException]( - testStreamRoundTrip(s"nonExistStage/$fileName", s"nonExistStage/$fileName", false) - ) + testStreamRoundTrip(s"nonExistStage/$fileName", s"nonExistStage/$fileName", false)) // Test download no stage assertThrows[SnowflakeSQLException]( - testStreamRoundTrip(s"$tempStage/$fileName", s"nonExistStage/$fileName", false) - ) + testStreamRoundTrip(s"$tempStage/$fileName", s"nonExistStage/$fileName", false)) stagePrefix = "prefix_" + TestUtils.randomString(5) fileName = s"streamFile_${TestUtils.randomString(5)}.csv" // Test download no file assertThrows[SnowparkClientException]( - testStreamRoundTrip(s"$tempStage/$fileName", s"$tempStage/$stagePrefix/$fileName", false) - ) + testStreamRoundTrip(s"$tempStage/$fileName", s"$tempStage/$stagePrefix/$fileName", false)) } private def testStreamRoundTrip( uploadLocation: String, downloadLocation: String, - compress: Boolean - ): Unit = { + compress: Boolean): Unit = { val fileContent = "test, file, csv" session.file.uploadStream( uploadLocation, new ByteArrayInputStream(fileContent.getBytes), - compress - ) + compress) assert( Source .fromInputStream(session.file.downloadStream(downloadLocation, compress)) .mkString - .equals(fileContent) - ) + .equals(fileContent)) } } diff --git a/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala b/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala index 1260ef0e..d5f47944 100644 --- a/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala @@ -34,8 +34,7 @@ trait FunctionSuite extends TestData { test("corr") { checkAnswer( number1.groupBy(col("K")).agg(corr(col("V1"), col("v2"))), - Seq(Row(1, 
null), Row(2, 0.40367115665231024)) - ) + Seq(Row(1, null), Row(2, 0.40367115665231024))) } test("count") { @@ -51,13 +50,11 @@ trait FunctionSuite extends TestData { test("covariance") { checkAnswer( number1.groupBy("K").agg(covar_pop(col("V1"), col("V2"))), - Seq(Row(1, 0), Row(2, 38.75)) - ) + Seq(Row(1, 0), Row(2, 38.75))) checkAnswer( number1.groupBy("K").agg(covar_samp(col("V1"), col("V2"))), - Seq(Row(1, null), Row(2, 51.666666666666664)) - ) + Seq(Row(1, null), Row(2, 51.666666666666664))) } test("grouping") { @@ -73,17 +70,14 @@ trait FunctionSuite extends TestData { Row(2, null, 0, 1, 1), Row(null, null, 1, 1, 3), Row(null, 2, 1, 0, 2), - Row(null, 1, 1, 0, 2) - ), - sort = false - ) + Row(null, 1, 1, 0, 2)), + sort = false) } test("kurtosis") { checkAnswer( xyz.select(kurtosis(col("X")), kurtosis(col("Y")), kurtosis(col("Z"))), - Seq(Row(-3.333333333333, 5.000000000000, 3.613736609956)) - ) + Seq(Row(-3.333333333333, 5.000000000000, 3.613736609956))) } test("max, min, mean") { @@ -106,99 +100,85 @@ trait FunctionSuite extends TestData { test("skew") { checkAnswer( xyz.select(skew(col("X")), skew(col("Y")), skew(col("Z"))), - Seq(Row(-0.6085811063146803, -2.236069766354172, 1.8414236309018863)) - ) + Seq(Row(-0.6085811063146803, -2.236069766354172, 1.8414236309018863))) } test("stddev") { checkAnswer( xyz.select(stddev(col("X")), stddev_samp(col("Y")), stddev_pop(col("Z"))), - Seq(Row(0.5477225575051661, 0.4472135954999579, 3.3226495451672298)) - ) + Seq(Row(0.5477225575051661, 0.4472135954999579, 3.3226495451672298))) } test("sum") { checkAnswer( duplicatedNumbers.groupBy("A").agg(sum(col("A"))), Seq(Row(3, 6), Row(2, 4), Row(1, 1)), - sort = false - ) + sort = false) checkAnswer( duplicatedNumbers.groupBy("A").agg(sum("A")), Seq(Row(3, 6), Row(2, 4), Row(1, 1)), - sort = false - ) + sort = false) checkAnswer( duplicatedNumbers.groupBy("A").agg(sum_distinct(col("A"))), Seq(Row(3, 3), Row(2, 2), Row(1, 1)), - sort = false - ) + sort = false) } 
test("variance") { checkAnswer( xyz.groupBy("X").agg(variance(col("Y")), var_pop(col("Z")), var_samp(col("Z"))), Seq(Row(1, 0.000000, 1.000000, 2.000000), Row(2, 0.333333, 14.888889, 22.333333)), - sort = false - ) + sort = false) } test("cume_dist") { checkAnswer( xyz.select(cume_dist().over(Window.partitionBy(col("X")).orderBy(col("Y")))), Seq(Row(0.3333333333333333), Row(1.0), Row(1.0), Row(1.0), Row(1.0)), - sort = false - ) + sort = false) } test("dense_rank") { checkAnswer( xyz.select(dense_rank().over(Window.orderBy(col("X")))), Seq(Row(1), Row(1), Row(2), Row(2), Row(2)), - sort = false - ) + sort = false) } test("lag") { checkAnswer( xyz.select(lag(col("Z"), 1, lit(0)).over(Window.partitionBy(col("X")).orderBy(col("X")))), Seq(Row(0), Row(10), Row(1), Row(0), Row(1)), - sort = false - ) + sort = false) checkAnswer( xyz.select(lag(col("Z"), 1).over(Window.partitionBy(col("X")).orderBy(col("X")))), Seq(Row(null), Row(10), Row(1), Row(null), Row(1)), - sort = false - ) + sort = false) checkAnswer( xyz.select(lag(col("Z")).over(Window.partitionBy(col("X")).orderBy(col("X")))), Seq(Row(null), Row(10), Row(1), Row(null), Row(1)), - sort = false - ) + sort = false) } test("lead") { checkAnswer( xyz.select(lead(col("Z"), 1, lit(0)).over(Window.partitionBy(col("X")).orderBy(col("X")))), Seq(Row(1), Row(3), Row(0), Row(3), Row(0)), - sort = false - ) + sort = false) checkAnswer( xyz.select(lead(col("Z"), 1).over(Window.partitionBy(col("X")).orderBy(col("X")))), Seq(Row(1), Row(3), Row(null), Row(3), Row(null)), - sort = false - ) + sort = false) checkAnswer( xyz.select(lead(col("Z")).over(Window.partitionBy(col("X")).orderBy(col("X")))), Seq(Row(1), Row(3), Row(null), Row(3), Row(null)), - sort = false - ) + sort = false) } test("ntile") { @@ -206,48 +186,42 @@ trait FunctionSuite extends TestData { checkAnswer( df.select(ntile(col("n")).over(Window.partitionBy(col("X")).orderBy(col("Y")))), Seq(Row(1), Row(2), Row(3), Row(1), Row(2)), - sort = false - ) + sort = 
false) } test("percent_rank") { checkAnswer( xyz.select(percent_rank().over(Window.partitionBy(col("X")).orderBy(col("Y")))), Seq(Row(0.0), Row(0.5), Row(0.5), Row(0.0), Row(0.0)), - sort = false - ) + sort = false) } test("rank") { checkAnswer( xyz.select(rank().over(Window.partitionBy(col("X")).orderBy(col("Y")))), Seq(Row(1), Row(2), Row(2), Row(1), Row(1)), - sort = false - ) + sort = false) } test("row_number") { checkAnswer( xyz.select(row_number().over(Window.partitionBy(col("X")).orderBy(col("Y")))), Seq(Row(1), Row(2), Row(3), Row(1), Row(2)), - sort = false - ) + sort = false) } test("coalesce") { checkAnswer( nullData2.select(coalesce(col("A"), col("B"), col("C"))), Seq(Row(1), Row(2), Row(3), Row(null), Row(1), Row(1), Row(1)), - sort = false - ) + sort = false) } test("NaN and Null") { checkAnswer( nanData1.select(equal_nan(col("A")), is_null(col("A"))), Seq(Row(false, false), Row(true, false), Row(null, true), Row(false, false)), - sort = false - ) + sort = false) } test("negate and not") { @@ -255,8 +229,7 @@ trait FunctionSuite extends TestData { checkAnswer( df.select(negate(col("A")), not(col("B"))), Seq(Row(-1, false), Row(2, true)), - sort = false - ) + sort = false) } test("random") { @@ -269,8 +242,7 @@ trait FunctionSuite extends TestData { checkAnswer( testData1.select(sqrt(col("NUM"))), Seq(Row(1.0), Row(1.4142135623730951)), - sort = false - ) + sort = false) } test("bitwise not") { @@ -293,14 +265,12 @@ trait FunctionSuite extends TestData { checkAnswer( xyz.select(greatest(col("X"), col("Y"), col("Z"))), Seq(Row(2), Row(3), Row(10), Row(2), Row(3)), - sort = false - ) + sort = false) checkAnswer( xyz.select(least(col("X"), col("Y"), col("Z"))), Seq(Row(1), Row(1), Row(1), Row(1), Row(2)), - sort = false - ) + sort = false) } test("round") { @@ -315,20 +285,16 @@ trait FunctionSuite extends TestData { Seq( Row(1.4706289056333368, 0.1001674211615598), Row(1.369438406004566, 0.2013579207903308), - Row(1.2661036727794992, 0.3046926540153975) 
- ), - sort = false - ) + Row(1.2661036727794992, 0.3046926540153975)), + sort = false) checkAnswer( double2.select(acos(col("A")), asin(col("A"))), Seq( Row(1.4706289056333368, 0.1001674211615598), Row(1.369438406004566, 0.2013579207903308), - Row(1.2661036727794992, 0.3046926540153975) - ), - sort = false - ) + Row(1.2661036727794992, 0.3046926540153975)), + sort = false) } test("atan atan2") { @@ -337,17 +303,14 @@ trait FunctionSuite extends TestData { Seq( Row(0.4636476090008061, 0.09966865249116204), Row(0.5404195002705842, 0.19739555984988078), - Row(0.6107259643892086, 0.2914567944778671) - ), - sort = false - ) + Row(0.6107259643892086, 0.2914567944778671)), + sort = false) checkAnswer( double2 .select(atan2(col("B"), col("A"))), Seq(Row(1.373400766945016), Row(1.2490457723982544), Row(1.1659045405098132)), - sort = false - ) + sort = false) } test("cos cosh") { @@ -356,10 +319,8 @@ trait FunctionSuite extends TestData { Seq( Row(0.9950041652780258, 0.9950041652780258, 1.1276259652063807, 1.1276259652063807), Row(0.9800665778412416, 0.9800665778412416, 1.1854652182422676, 1.1854652182422676), - Row(0.955336489125606, 0.955336489125606, 1.255169005630943, 1.255169005630943) - ), - sort = false - ) + Row(0.955336489125606, 0.955336489125606, 1.255169005630943, 1.255169005630943)), + sort = false) } test("exp") { @@ -368,10 +329,8 @@ trait FunctionSuite extends TestData { Seq( Row(2.718281828459045, 2.718281828459045), Row(1.0, 1.0), - Row(0.006737946999085467, 0.006737946999085467) - ), - sort = false - ) + Row(0.006737946999085467, 0.006737946999085467)), + sort = false) } test("factorial") { @@ -382,23 +341,20 @@ trait FunctionSuite extends TestData { checkAnswer( integer1.select(log(lit(2), col("A")), log(lit(4), col("A"))), Seq(Row(0.0, 0.0), Row(1.0, 0.5), Row(1.5849625007211563, 0.7924812503605781)), - sort = false - ) + sort = false) } test("pow") { checkAnswer( double2.select(pow(col("A"), col("B"))), Seq(Row(0.31622776601683794), 
Row(0.3807307877431757), Row(0.4305116202499342)), - sort = false - ) + sort = false) } test("shiftleft shiftright") { checkAnswer( integer1.select(bitshiftleft(col("A"), lit(1)), bitshiftright(col("A"), lit(1))), Seq(Row(2, 0), Row(4, 1), Row(6, 1)), - sort = false - ) + sort = false) } test("sin sinh") { @@ -407,10 +363,8 @@ trait FunctionSuite extends TestData { Seq( Row(0.09983341664682815, 0.09983341664682815, 0.10016675001984403, 0.10016675001984403), Row(0.19866933079506122, 0.19866933079506122, 0.20133600254109402, 0.20133600254109402), - Row(0.29552020666133955, 0.29552020666133955, 0.3045202934471426, 0.3045202934471426) - ), - sort = false - ) + Row(0.29552020666133955, 0.29552020666133955, 0.3045202934471426, 0.3045202934471426)), + sort = false) } test("tan tanh") { @@ -419,10 +373,8 @@ trait FunctionSuite extends TestData { Seq( Row(0.10033467208545055, 0.10033467208545055, 0.09966799462495582, 0.09966799462495582), Row(0.2027100355086725, 0.2027100355086725, 0.197375320224904, 0.197375320224904), - Row(0.30933624960962325, 0.30933624960962325, 0.2913126124515909, 0.2913126124515909) - ), - sort = false - ) + Row(0.30933624960962325, 0.30933624960962325, 0.2913126124515909, 0.2913126124515909)), + sort = false) } test("degrees") { @@ -431,10 +383,8 @@ trait FunctionSuite extends TestData { Seq( Row(5.729577951308233, 28.64788975654116), Row(11.459155902616466, 34.37746770784939), - Row(17.188733853924695, 40.10704565915762) - ), - sort = false - ) + Row(17.188733853924695, 40.10704565915762)), + sort = false) } test("radians") { @@ -443,10 +393,8 @@ trait FunctionSuite extends TestData { Seq( Row(0.019390607989657, 0.019390607989657), Row(0.038781215979314, 0.038781215979314), - Row(0.058171823968971005, 0.058171823968971005) - ), - sort = false - ) + Row(0.058171823968971005, 0.058171823968971005)), + sort = false) } test("md5 sha1 sha2") { @@ -469,16 +417,14 @@ trait FunctionSuite extends TestData { 
"d2d5c076b2435565f66649edd604dd5987163e8a8240953144ec652f" ) ), // pragma: allowlist secret - sort = false - ) + sort = false) } test("hash") { checkAnswer( string1.select(hash(col("A"))), Seq(Row(-1996792119384707157L), Row(-410379000639015509L), Row(9028932499781431792L)), - sort = false - ) + sort = false) } test("ascii") { @@ -489,47 +435,41 @@ trait FunctionSuite extends TestData { checkAnswer( string1.select(concat_ws(lit(","), col("A"), col("B"))), Seq(Row("test1,a"), Row("test2,b"), Row("test3,c")), - sort = false - ) + sort = false) } test("initcap length lower upper") { checkAnswer( string2.select(initcap(col("A")), length(col("A")), lower(col("A")), upper(col("A"))), Seq(Row("Asdfg", 5, "asdfg", "ASDFG"), Row("Qqq", 3, "qqq", "QQQ"), Row("Qw", 2, "qw", "QW")), - sort = false - ) + sort = false) } test("lpad rpad") { checkAnswer( string2.select(lpad(col("A"), lit(8), lit("X")), rpad(col("A"), lit(9), lit("S"))), Seq(Row("XXXasdFg", "asdFgSSSS"), Row("XXXXXqqq", "qqqSSSSSS"), Row("XXXXXXQw", "QwSSSSSSS")), - sort = false - ) + sort = false) } test("ltrim rtrim, trim") { checkAnswer( string3.select(ltrim(col("A")), rtrim(col("A"))), Seq(Row("abcba ", " abcba"), Row("a12321a ", " a12321a")), - sort = false - ) + sort = false) checkAnswer( string3 .select(ltrim(col("A"), lit(" a")), rtrim(col("A"), lit(" a")), trim(col("A"), lit("a "))), Seq(Row("bcba ", " abcb", "bcb"), Row("12321a ", " a12321", "12321")), - sort = false - ) + sort = false) } test("repeat") { checkAnswer( string1.select(repeat(col("B"), lit(3))), Seq(Row("aaa"), Row("bbb"), Row("ccc")), - sort = false - ) + sort = false) } test("builtin function") { @@ -537,40 +477,35 @@ trait FunctionSuite extends TestData { checkAnswer( string1.select(repeat(col("B"), 3)), Seq(Row("aaa"), Row("bbb"), Row("ccc")), - sort = false - ) + sort = false) } test("soundex") { checkAnswer( string4.select(soundex(col("A"))), Seq(Row("a140"), Row("b550"), Row("p200")), - sort = false - ) + sort = false) } test("sub 
string") { checkAnswer( string1.select(substring(col("A"), lit(2), lit(4))), Seq(Row("est1"), Row("est2"), Row("est3")), - sort = false - ) + sort = false) } test("translate") { checkAnswer( string3.select(translate(col("A"), lit("ab "), lit("XY"))), Seq(Row("XYcYX"), Row("X12321X")), - sort = false - ) + sort = false) } test("add months, current date") { checkAnswer( date1.select(add_months(col("A"), lit(1))), Seq(Row(Date.valueOf("2020-09-01")), Row(Date.valueOf("2011-01-01"))), - sort = false - ) + sort = false) // zero1.select(current_date()) gets the date on server, which uses session timezone. // System.currentTimeMillis() is based on jvm timezone. They should not always be equal. // We can set local JVM timezone to session timezone to ensure it passes. @@ -579,31 +514,25 @@ trait FunctionSuite extends TestData { { checkAnswer(zero1.select(current_date()), Seq(Row(new Date(System.currentTimeMillis())))) }, - getTimeZone(session) - ), + getTimeZone(session)), "TIMEZONE", - "'GMT'" - ) + "'GMT'") testWithAlteredSessionParameter( testWithTimezone( { checkAnswer(zero1.select(current_date()), Seq(Row(new Date(System.currentTimeMillis())))) }, - getTimeZone(session) - ), + getTimeZone(session)), "TIMEZONE", - "'Etc/GMT+8'" - ) + "'Etc/GMT+8'") testWithAlteredSessionParameter( testWithTimezone( { checkAnswer(zero1.select(current_date()), Seq(Row(new Date(System.currentTimeMillis())))) }, - getTimeZone(session) - ), + getTimeZone(session)), "TIMEZONE", - "'Etc/GMT-8'" - ) + "'Etc/GMT-8'") } test("current timestamp") { @@ -612,8 +541,7 @@ trait FunctionSuite extends TestData { .select(current_timestamp()) .collect()(0) .getTimestamp(0) - .getTime).abs < 100000 - ) + .getTime).abs < 100000) } test("year month day week quarter") { @@ -626,22 +554,18 @@ trait FunctionSuite extends TestData { dayofyear(col("A")), quarter(col("A")), weekofyear(col("A")), - last_day(col("A")) - ), + last_day(col("A"))), Seq( Row(2020, 8, 1, 6, 214, 3, 31, new Date(120, 7, 31)), - Row(2010, 
12, 1, 3, 335, 4, 48, new Date(110, 11, 31)) - ), - sort = false - ) + Row(2010, 12, 1, 3, 335, 4, 48, new Date(110, 11, 31))), + sort = false) } test("next day") { checkAnswer( date1.select(next_day(col("A"), lit("FR"))), Seq(Row(new Date(120, 7, 7)), Row(new Date(110, 11, 3))), - sort = false - ) + sort = false) } test("previous day") { @@ -649,16 +573,14 @@ trait FunctionSuite extends TestData { checkAnswer( date2.select(previous_day(col("a"), col("b"))), Seq(Row(new Date(120, 6, 27)), Row(new Date(110, 10, 24))), - sort = false - ) + sort = false) } test("hour minute second") { checkAnswer( timestamp1.select(hour(col("A")), minute(col("A")), second(col("A"))), Seq(Row(13, 11, 20), Row(1, 30, 5)), - sort = false - ) + sort = false) } test("datediff") { @@ -666,16 +588,14 @@ trait FunctionSuite extends TestData { timestamp1 .select(col("a"), dateadd("year", lit(1), col("a")).as("b")) .select(datediff("year", col("a"), col("b"))), - Seq(Row(1), Row(1)) - ) + Seq(Row(1), Row(1))) } test("dateadd") { checkAnswer( date1.select(dateadd("year", lit(1), col("a"))), Seq(Row(new Date(121, 7, 1)), Row(new Date(111, 11, 1))), - sort = false - ) + sort = false) } test("to_timestamp") { @@ -684,41 +604,34 @@ trait FunctionSuite extends TestData { Seq( Row(Timestamp.valueOf("2019-06-25 16:19:17.0")), Row(Timestamp.valueOf("2019-08-10 23:25:57.0")), - Row(Timestamp.valueOf("2006-10-22 01:12:37.0")) - ), - sort = false - ) + Row(Timestamp.valueOf("2006-10-22 01:12:37.0"))), + sort = false) val df = session.sql("select * from values('04/05/2020 01:02:03') as T(a)") checkAnswer( df.select(to_timestamp(col("A"), lit("mm/dd/yyyy hh24:mi:ss"))), - Seq(Row(Timestamp.valueOf("2020-04-05 01:02:03.0"))) - ) + Seq(Row(Timestamp.valueOf("2020-04-05 01:02:03.0")))) } test("convert_timezone") { checkAnswer( timestampNTZ.select( - convert_timezone(lit("America/Los_Angeles"), lit("America/New_York"), col("a")) - ), + convert_timezone(lit("America/Los_Angeles"), lit("America/New_York"), 
col("a"))), Seq( Row(Timestamp.valueOf("2020-05-01 16:11:20.0")), - Row(Timestamp.valueOf("2020-08-21 04:30:05.0")) - ), - sort = false - ) + Row(Timestamp.valueOf("2020-08-21 04:30:05.0"))), + sort = false) val df = Seq(("2020-05-01 16:11:20.0 +02:00", "2020-08-21 04:30:05.0 -06:00")).toDF("a", "b") checkAnswer( df.select( convert_timezone(lit("America/Los_Angeles"), col("a")), - convert_timezone(lit("America/New_York"), col("b")) - ), + convert_timezone(lit("America/New_York"), col("b"))), Seq( - Row(Timestamp.valueOf("2020-05-01 07:11:20.0"), Timestamp.valueOf("2020-08-21 06:30:05.0")) - ) - ) + Row( + Timestamp.valueOf("2020-05-01 07:11:20.0"), + Timestamp.valueOf("2020-08-21 06:30:05.0")))) // -06:00 -> New_York should be -06:00 -> -04:00, which is +2 hours. } @@ -735,10 +648,8 @@ trait FunctionSuite extends TestData { timestamp1.select(date_trunc("quarter", col("A"))), Seq( Row(Timestamp.valueOf("2020-04-01 00:00:00.0")), - Row(Timestamp.valueOf("2020-07-01 00:00:00.0")) - ), - sort = false - ) + Row(Timestamp.valueOf("2020-07-01 00:00:00.0"))), + sort = false) } test("trunc") { @@ -750,8 +661,7 @@ trait FunctionSuite extends TestData { checkAnswer( string1.select(concat(col("A"), col("B"))), Seq(Row("test1a"), Row("test2b"), Row("test3c")), - sort = false - ) + sort = false) } test("split") { @@ -760,24 +670,21 @@ trait FunctionSuite extends TestData { .select(split(col("A"), lit(","))) .collect()(0) .getString(0) - .replaceAll("[ \n]", "") == "[\"1\",\"2\",\"3\",\"4\",\"5\"]" - ) + .replaceAll("[ \n]", "") == "[\"1\",\"2\",\"3\",\"4\",\"5\"]") } test("contains") { checkAnswer( string4.select(contains(col("a"), lit("app"))), Seq(Row(true), Row(false), Row(false)), - sort = false - ) + sort = false) } test("startswith") { checkAnswer( string4.select(startswith(col("a"), lit("ban"))), Seq(Row(false), Row(true), Row(false)), - sort = false - ) + sort = false) } test("char") { @@ -786,8 +693,7 @@ trait FunctionSuite extends TestData { checkAnswer( 
df.select(char(col("A")), char(col("B"))), Seq(Row("T", "U"), Row("`", "a")), - sort = false - ) + sort = false) } test("array_overlap") { @@ -795,16 +701,14 @@ trait FunctionSuite extends TestData { array1 .select(arrays_overlap(col("ARR1"), col("ARR2"))), Seq(Row(true), Row(false)), - sort = false - ) + sort = false) } test("array_intersection") { checkAnswer( array1.select(array_intersection(col("ARR1"), col("ARR2"))), Seq(Row("[\n 3\n]"), Row("[]")), - sort = false - ) + sort = false) } test("is_array") { @@ -812,54 +716,46 @@ trait FunctionSuite extends TestData { checkAnswer( variant1.select(is_array(col("arr1")), is_array(col("bool1")), is_array(col("str1"))), Seq(Row(true, false, false)), - sort = false - ) + sort = false) } test("is_boolean") { checkAnswer( variant1.select(is_boolean(col("arr1")), is_boolean(col("bool1")), is_boolean(col("str1"))), Seq(Row(false, true, false)), - sort = false - ) + sort = false) } test("is_binary") { checkAnswer( variant1.select(is_binary(col("bin1")), is_binary(col("bool1")), is_binary(col("str1"))), Seq(Row(true, false, false)), - sort = false - ) + sort = false) } test("is_char/is_varchar") { checkAnswer( variant1.select(is_char(col("str1")), is_char(col("bin1")), is_char(col("bool1"))), Seq(Row(true, false, false)), - sort = false - ) + sort = false) checkAnswer( variant1.select(is_varchar(col("str1")), is_varchar(col("bin1")), is_varchar(col("bool1"))), Seq(Row(true, false, false)), - sort = false - ) + sort = false) } test("is_date/is_date_value") { checkAnswer( variant1.select(is_date(col("date1")), is_date(col("time1")), is_date(col("bool1"))), Seq(Row(true, false, false)), - sort = false - ) + sort = false) checkAnswer( variant1.select( is_date_value(col("date1")), is_date_value(col("time1")), - is_date_value(col("str1")) - ), + is_date_value(col("str1"))), Seq(Row(true, false, false)), - sort = false - ) + sort = false) } test("is_decimal") { @@ -867,8 +763,7 @@ trait FunctionSuite extends TestData { variant1 
.select(is_decimal(col("decimal1")), is_decimal(col("double1")), is_decimal(col("num1"))), Seq(Row(true, false, true)), - sort = false - ) + sort = false) } test("is_double/is_real") { @@ -877,22 +772,18 @@ trait FunctionSuite extends TestData { is_double(col("decimal1")), is_double(col("double1")), is_double(col("num1")), - is_double(col("bool1")) - ), + is_double(col("bool1"))), Seq(Row(true, true, true, false)), - sort = false - ) + sort = false) checkAnswer( variant1.select( is_real(col("decimal1")), is_real(col("double1")), is_real(col("num1")), - is_real(col("bool1")) - ), + is_real(col("bool1"))), Seq(Row(true, true, true, false)), - sort = false - ) + sort = false) } test("is_integer") { @@ -901,27 +792,23 @@ trait FunctionSuite extends TestData { is_integer(col("decimal1")), is_integer(col("double1")), is_integer(col("num1")), - is_integer(col("bool1")) - ), + is_integer(col("bool1"))), Seq(Row(false, false, true, false)), - sort = false - ) + sort = false) } test("is_null_value") { checkAnswer( nullJson1.select(is_null_value(sqlExpr("v:a"))), Seq(Row(true), Row(false), Row(null)), - sort = false - ) + sort = false) } test("is_object") { checkAnswer( variant1.select(is_object(col("obj1")), is_object(col("arr1")), is_object(col("str1"))), Seq(Row(true, false, false)), - sort = false - ) + sort = false) } test("is_time") { @@ -929,8 +816,7 @@ trait FunctionSuite extends TestData { variant1 .select(is_time(col("time1")), is_time(col("date1")), is_time(col("timestamp_tz1"))), Seq(Row(true, false, false)), - sort = false - ) + sort = false) } test("is_timestamp_*") { @@ -938,31 +824,25 @@ trait FunctionSuite extends TestData { variant1.select( is_timestamp_ntz(col("timestamp_ntz1")), is_timestamp_ntz(col("timestamp_tz1")), - is_timestamp_ntz(col("timestamp_ltz1")) - ), + is_timestamp_ntz(col("timestamp_ltz1"))), Seq(Row(true, false, false)), - sort = false - ) + sort = false) checkAnswer( variant1.select( is_timestamp_ltz(col("timestamp_ntz1")), 
is_timestamp_ltz(col("timestamp_tz1")), - is_timestamp_ltz(col("timestamp_ltz1")) - ), + is_timestamp_ltz(col("timestamp_ltz1"))), Seq(Row(false, false, true)), - sort = false - ) + sort = false) checkAnswer( variant1.select( is_timestamp_tz(col("timestamp_ntz1")), is_timestamp_tz(col("timestamp_tz1")), - is_timestamp_tz(col("timestamp_ltz1")) - ), + is_timestamp_tz(col("timestamp_ltz1"))), Seq(Row(false, true, false)), - sort = false - ) + sort = false) } test("current_region") { @@ -1000,8 +880,7 @@ trait FunctionSuite extends TestData { .collect()(0) .getString(0) .trim - .startsWith("SELECT current_statement()") - ) + .startsWith("SELECT current_statement()")) } test("current_available_roles") { @@ -1027,8 +906,7 @@ trait FunctionSuite extends TestData { .select(current_user()) .collect()(0) .getString(0) - .equalsIgnoreCase(getUserFromProperties) - ) + .equalsIgnoreCase(getUserFromProperties)) } test("current_database") { @@ -1037,8 +915,7 @@ trait FunctionSuite extends TestData { .select(current_database()) .collect()(0) .getString(0) - .equalsIgnoreCase(getDatabaseFromProperties.replaceAll("""^"|"$""", "")) - ) + .equalsIgnoreCase(getDatabaseFromProperties.replaceAll("""^"|"$""", ""))) } test("current_schema") { @@ -1047,8 +924,7 @@ trait FunctionSuite extends TestData { .select(current_schema()) .collect()(0) .getString(0) - .equalsIgnoreCase(getSchemaFromProperties.replaceAll("""^"|"$""", "")) - ) + .equalsIgnoreCase(getSchemaFromProperties.replaceAll("""^"|"$""", ""))) } test("current_schemas") { @@ -1070,8 +946,7 @@ trait FunctionSuite extends TestData { .select(current_warehouse()) .collect()(0) .getString(0) - .equalsIgnoreCase(getWarehouseFromProperties.replaceAll("""^"|"$""", "")) - ) + .equalsIgnoreCase(getWarehouseFromProperties.replaceAll("""^"|"$""", ""))) } test("date_from_parts") { @@ -1096,24 +971,21 @@ trait FunctionSuite extends TestData { checkAnswer( string4.select(insert(col("a"), lit(2), lit(3), lit("abc"))), Seq(Row("aabce"), 
Row("babcna"), Row("pabch")), - sort = false - ) + sort = false) } test("left") { checkAnswer( string4.select(left(col("a"), lit(2))), Seq(Row("ap"), Row("ba"), Row("pe")), - sort = false - ) + sort = false) } test("right") { checkAnswer( string4.select(right(col("a"), lit(2))), Seq(Row("le"), Row("na"), Row("ch")), - sort = false - ) + sort = false) } test("sysdate") { @@ -1123,36 +995,31 @@ trait FunctionSuite extends TestData { .collect()(0) .getTimestamp(0) .toString - .length > 0 - ) + .length > 0) } test("regexp_count") { checkAnswer( string4.select(regexp_count(col("a"), lit("a"))), Seq(Row(1), Row(3), Row(1)), - sort = false - ) + sort = false) checkAnswer( string4.select(regexp_count(col("a"), lit("a"), lit(2), lit("c"))), Seq(Row(0), Row(3), Row(1)), - sort = false - ) + sort = false) } test("replace") { checkAnswer( string4.select(replace(col("a"), lit("a"))), Seq(Row("pple"), Row("bnn"), Row("pech")), - sort = false - ) + sort = false) checkAnswer( string4.select(replace(col("a"), lit("a"), lit("z"))), Seq(Row("zpple"), Row("bznznz"), Row("pezch")), - sort = false - ) + sort = false) } @@ -1162,30 +1029,26 @@ trait FunctionSuite extends TestData { .select(time_from_parts(lit(1), lit(2), lit(3))) .collect()(0) .getTime(0) - .equals(new Time(3723000)) - ) + .equals(new Time(3723000))) assert( zero1 .select(time_from_parts(lit(1), lit(2), lit(3), lit(444444444))) .collect()(0) .getTime(0) - .equals(new Time(3723444)) - ) + .equals(new Time(3723444))) } test("charindex") { checkAnswer( string4.select(charindex(lit("na"), col("a"))), Seq(Row(0), Row(3), Row(0)), - sort = false - ) + sort = false) checkAnswer( string4.select(charindex(lit("na"), col("a"), lit(4))), Seq(Row(0), Row(5), Row(0)), - sort = false - ) + sort = false) } test("collate") { @@ -1210,13 +1073,10 @@ trait FunctionSuite extends TestData { col("day"), col("hour"), col("minute"), - col("second") - ) - ) + col("second"))) .collect()(0) .getTimestamp(0) - .toString == "2020-10-28 13:35:47.0" 
- ) + .toString == "2020-10-28 13:35:47.0") assert( date3 @@ -1228,22 +1088,16 @@ trait FunctionSuite extends TestData { col("hour"), col("minute"), col("second"), - col("nanosecond") - ) - ) + col("nanosecond"))) .collect()(0) .getTimestamp(0) - .toString == "2020-10-28 13:35:47.001234567" - ) + .toString == "2020-10-28 13:35:47.001234567") assert( date3 - .select( - timestamp_from_parts( - date_from_parts(col("year"), col("month"), col("day")), - time_from_parts(col("hour"), col("minute"), col("second"), col("nanosecond")) - ) - ) + .select(timestamp_from_parts( + date_from_parts(col("year"), col("month"), col("day")), + time_from_parts(col("hour"), col("minute"), col("second"), col("nanosecond")))) .collect()(0) .getTimestamp(0) .toString == "2020-10-28 13:35:47.001234567") @@ -1259,13 +1113,10 @@ trait FunctionSuite extends TestData { col("day"), col("hour"), col("minute"), - col("second") - ) - ) + col("second"))) .collect()(0) .getTimestamp(0) - .toString == "2020-10-28 13:35:47.0" - ) + .toString == "2020-10-28 13:35:47.0") assert( date3 @@ -1277,13 +1128,10 @@ trait FunctionSuite extends TestData { col("hour"), col("minute"), col("second"), - col("nanosecond") - ) - ) + col("nanosecond"))) .collect()(0) .getTimestamp(0) - .toString == "2020-10-28 13:35:47.001234567" - ) + .toString == "2020-10-28 13:35:47.001234567") } test("timestamp_ntz_from_parts") { @@ -1296,13 +1144,10 @@ trait FunctionSuite extends TestData { col("day"), col("hour"), col("minute"), - col("second") - ) - ) + col("second"))) .collect()(0) .getTimestamp(0) - .toString == "2020-10-28 13:35:47.0" - ) + .toString == "2020-10-28 13:35:47.0") assert( date3 @@ -1314,26 +1159,19 @@ trait FunctionSuite extends TestData { col("hour"), col("minute"), col("second"), - col("nanosecond") - ) - ) + col("nanosecond"))) .collect()(0) .getTimestamp(0) - .toString == "2020-10-28 13:35:47.001234567" - ) + .toString == "2020-10-28 13:35:47.001234567") assert( date3 - .select( - timestamp_ntz_from_parts( - 
date_from_parts(col("year"), col("month"), col("day")), - time_from_parts(col("hour"), col("minute"), col("second"), col("nanosecond")) - ) - ) + .select(timestamp_ntz_from_parts( + date_from_parts(col("year"), col("month"), col("day")), + time_from_parts(col("hour"), col("minute"), col("second"), col("nanosecond")))) .collect()(0) .getTimestamp(0) - .toString == "2020-10-28 13:35:47.001234567" - ) + .toString == "2020-10-28 13:35:47.001234567") } test("timestamp_tz_from_parts") { @@ -1346,13 +1184,10 @@ trait FunctionSuite extends TestData { col("day"), col("hour"), col("minute"), - col("second") - ) - ) + col("second"))) .collect()(0) .getTimestamp(0) - .toString == "2020-10-28 13:35:47.0" - ) + .toString == "2020-10-28 13:35:47.0") assert( date3 @@ -1364,13 +1199,10 @@ trait FunctionSuite extends TestData { col("hour"), col("minute"), col("second"), - col("nanosecond") - ) - ) + col("nanosecond"))) .collect()(0) .getTimestamp(0) - .toString == "2020-10-28 13:35:47.001234567" - ) + .toString == "2020-10-28 13:35:47.001234567") assert( date3 @@ -1383,13 +1215,10 @@ trait FunctionSuite extends TestData { col("minute"), col("second"), col("nanosecond"), - col("timezone") - ) - ) + col("timezone"))) .collect()(0) .getTimestamp(0) - .toString == "2020-10-28 13:35:47.001234567" - ) + .toString == "2020-10-28 13:35:47.001234567") assert( date3 @@ -1402,13 +1231,10 @@ trait FunctionSuite extends TestData { col("minute"), col("second"), col("nanosecond"), - lit("America/New_York") - ) - ) + lit("America/New_York"))) .collect()(0) .getTimestamp(0) - .toGMTString == "28 Oct 2020 17:35:47 GMT" - ) + .toGMTString == "28 Oct 2020 17:35:47 GMT") } @@ -1416,52 +1242,44 @@ trait FunctionSuite extends TestData { checkAnswer( nullJson1.select(check_json(col("v"))), Seq(Row(null), Row(null), Row(null)), - sort = false - ) + sort = false) checkAnswer( invalidJson1.select(check_json(col("v"))), Seq( Row("incomplete object value, pos 11"), Row("missing colon, pos 7"), - Row("unfinished 
string, pos 5") - ), - sort = false - ) + Row("unfinished string, pos 5")), + sort = false) } test("check_xml") { checkAnswer( nullXML1.select(check_xml(col("v"))), Seq(Row(null), Row(null), Row(null), Row(null)), - sort = false - ) + sort = false) checkAnswer( invalidXML1.select(check_xml(col("v"))), Seq( Row("no opening tag for , pos 8"), Row("missing closing tags: , pos 8"), - Row("bad character in XML tag name: '<', pos 4") - ), - sort = false - ) + Row("bad character in XML tag name: '<', pos 4")), + sort = false) } test("json_extract_path_text") { checkAnswer( validJson1.select(json_extract_path_text(col("v"), col("k"))), Seq(Row(null), Row("foo"), Row(null), Row(null)), - sort = false - ) + sort = false) } test("parse_json") { checkAnswer( nullJson1.select(parse_json(col("v"))), Seq(Row("{\n \"a\": null\n}"), Row("{\n \"a\": \"foo\"\n}"), Row(null)), - sort = false - ) + sort = false) } test("parse_xml") { @@ -1471,32 +1289,27 @@ trait FunctionSuite extends TestData { Row("\n foo\n bar\n \n"), Row(""), Row(null), - Row(null) - ), - sort = false - ) + Row(null)), + sort = false) } test("strip_null_value") { checkAnswer( nullJson1.select(sqlExpr("v:a")), Seq(Row("null"), Row("\"foo\""), Row(null)), - sort = false - ) + sort = false) checkAnswer( nullJson1.select(strip_null_value(sqlExpr("v:a"))), Seq(Row(null), Row("\"foo\""), Row(null)), - sort = false - ) + sort = false) } test("array_agg") { assert( monthlySales.select(array_agg(col("amount"))).collect()(0).get(0).toString == "[\n 10000,\n 400,\n 4500,\n 35000,\n 5000,\n 3000,\n 200,\n 90500,\n 6000,\n " + - "5000,\n 2500,\n 9500,\n 8000,\n 10000,\n 800,\n 4500\n]" - ) + "5000,\n 2500,\n 9500,\n 8000,\n 10000,\n 800,\n 4500\n]") } test("array_agg WITHIN GROUP") { @@ -1507,8 +1320,7 @@ trait FunctionSuite extends TestData { .get(0) .toString == "[\n 200,\n 400,\n 800,\n 2500,\n 3000,\n 4500,\n 4500,\n 5000,\n 5000,\n " + - "6000,\n 8000,\n 9500,\n 10000,\n 10000,\n 35000,\n 90500\n]" - ) + "6000,\n 8000,\n 
9500,\n 10000,\n 10000,\n 35000,\n 90500\n]") } test("array_agg WITHIN GROUP ORDER BY DESC") { @@ -1519,8 +1331,7 @@ trait FunctionSuite extends TestData { .get(0) .toString == "[\n 90500,\n 35000,\n 10000,\n 10000,\n 9500,\n 8000,\n 6000,\n 5000,\n" + - " 5000,\n 4500,\n 4500,\n 3000,\n 2500,\n 800,\n 400,\n 200\n]" - ) + " 5000,\n 4500,\n 4500,\n 3000,\n 2500,\n 800,\n 400,\n 200\n]") } test("array_agg WITHIN GROUP ORDER BY multiple columns") { @@ -1533,8 +1344,7 @@ trait FunctionSuite extends TestData { .select(array_agg(col("amount")).withinGroup(sortColumns)) .collect()(0) .get(0) - .toString == expected - ) + .toString == expected) } test("window function array_agg WITHIN GROUP") { @@ -1544,11 +1354,9 @@ trait FunctionSuite extends TestData { xyz.select( array_agg(col("Z")) .withinGroup(Seq(col("Z"), col("Y"))) - .over(Window.partitionBy(col("X"))) - ), + .over(Window.partitionBy(col("X")))), Seq(Row(value1), Row(value1), Row(value2), Row(value2), Row(value2)), - false - ) + false) } test("array_append") { @@ -1556,10 +1364,8 @@ trait FunctionSuite extends TestData { array1.select(array_append(array_append(col("arr1"), lit("amount")), lit(32.21))), Seq( Row("[\n 1,\n 2,\n 3,\n \"amount\",\n 3.221000000000000e+01\n]"), - Row("[\n 6,\n 7,\n 8,\n \"amount\",\n 3.221000000000000e+01\n]") - ), - sort = false - ) + Row("[\n 6,\n 7,\n 8,\n \"amount\",\n 3.221000000000000e+01\n]")), + sort = false) // Get array result in List[Variant] val resultSet = @@ -1569,26 +1375,22 @@ trait FunctionSuite extends TestData { new Variant(2), new Variant(3), new Variant("amount"), - new Variant("3.221000000000000e+01") - ) + new Variant("3.221000000000000e+01")) assert(resultSet(0).getSeqOfVariant(0).equals(row1)) val row2 = Seq( new Variant(6), new Variant(7), new Variant(8), new Variant("amount"), - new Variant("3.221000000000000e+01") - ) + new Variant("3.221000000000000e+01")) assert(resultSet(1).getSeqOfVariant(0).equals(row2)) checkAnswer( 
array2.select(array_append(array_append(col("arr1"), col("d")), col("e"))), Seq( Row("[\n 1,\n 2,\n 3,\n 2,\n \"e1\"\n]"), - Row("[\n 6,\n 7,\n 8,\n 1,\n \"e2\"\n]") - ), - sort = false - ) + Row("[\n 6,\n 7,\n 8,\n 1,\n \"e2\"\n]")), + sort = false) } test("array_cat") { @@ -1596,18 +1398,15 @@ trait FunctionSuite extends TestData { array1.select(array_cat(col("arr1"), col("arr2"))), Seq( Row("[\n 1,\n 2,\n 3,\n 3,\n 4,\n 5\n]"), - Row("[\n 6,\n 7,\n 8,\n 9,\n 0,\n 1\n]") - ), - sort = false - ) + Row("[\n 6,\n 7,\n 8,\n 9,\n 0,\n 1\n]")), + sort = false) } test("array_compact") { checkAnswer( nullArray1.select(array_compact(col("arr1"))), Seq(Row("[\n 1,\n 3\n]"), Row("[\n 6,\n 8\n]")), - sort = false - ) + sort = false) } test("array_construct") { @@ -1616,16 +1415,14 @@ trait FunctionSuite extends TestData { .select(array_construct(lit(1), lit(1.2), lit("string"), lit(""), lit(null))) .collect()(0) .getString(0) == - "[\n 1,\n 1.200000000000000e+00,\n \"string\",\n \"\",\n undefined\n]" - ) + "[\n 1,\n 1.200000000000000e+00,\n \"string\",\n \"\",\n undefined\n]") assert( zero1 .select(array_construct()) .collect()(0) .getString(0) == - "[]" - ) + "[]") checkAnswer( integer1 @@ -1633,10 +1430,8 @@ trait FunctionSuite extends TestData { Seq( Row("[\n 1,\n 1.200000000000000e+00,\n undefined\n]"), Row("[\n 2,\n 1.200000000000000e+00,\n undefined\n]"), - Row("[\n 3,\n 1.200000000000000e+00,\n undefined\n]") - ), - sort = false - ) + Row("[\n 3,\n 1.200000000000000e+00,\n undefined\n]")), + sort = false) } test("array_construct_compact") { @@ -1645,16 +1440,14 @@ trait FunctionSuite extends TestData { .select(array_construct_compact(lit(1), lit(1.2), lit("string"), lit(""), lit(null))) .collect()(0) .getString(0) == - "[\n 1,\n 1.200000000000000e+00,\n \"string\",\n \"\"\n]" - ) + "[\n 1,\n 1.200000000000000e+00,\n \"string\",\n \"\"\n]") assert( zero1 .select(array_construct_compact()) .collect()(0) .getString(0) == - "[]" - ) + "[]") checkAnswer( integer1 @@ 
-1662,10 +1455,8 @@ trait FunctionSuite extends TestData { Seq( Row("[\n 1,\n 1.200000000000000e+00\n]"), Row("[\n 2,\n 1.200000000000000e+00\n]"), - Row("[\n 3,\n 1.200000000000000e+00\n]") - ), - sort = false - ) + Row("[\n 3,\n 1.200000000000000e+00\n]")), + sort = false) } test("array_contains") { @@ -1673,38 +1464,33 @@ trait FunctionSuite extends TestData { zero1 .select(array_contains(lit(1), array_construct(lit(1), lit(1.2), lit("string")))) .collect()(0) - .getBoolean(0) - ) + .getBoolean(0)) assert( !zero1 .select(array_contains(lit(-1), array_construct(lit(1), lit(1.2), lit("string")))) .collect()(0) - .getBoolean(0) - ) + .getBoolean(0)) checkAnswer( integer1 .select(array_contains(col("a"), array_construct(lit(1), lit(1.2), lit("string")))), Seq(Row(true), Row(false), Row(false)), - sort = false - ) + sort = false) } test("array_insert") { checkAnswer( array2.select(array_insert(col("arr1"), col("d"), col("e"))), Seq(Row("[\n 1,\n 2,\n \"e1\",\n 3\n]"), Row("[\n 6,\n \"e2\",\n 7,\n 8\n]")), - sort = false - ) + sort = false) } test("array_position") { checkAnswer( array2.select(array_position(col("d"), col("arr1"))), Seq(Row(1), Row(null)), - sort = false - ) + sort = false) } test("array_prepend") { @@ -1712,19 +1498,15 @@ trait FunctionSuite extends TestData { array1.select(array_prepend(array_prepend(col("arr1"), lit("amount")), lit(32.21))), Seq( Row("[\n 3.221000000000000e+01,\n \"amount\",\n 1,\n 2,\n 3\n]"), - Row("[\n 3.221000000000000e+01,\n \"amount\",\n 6,\n 7,\n 8\n]") - ), - sort = false - ) + Row("[\n 3.221000000000000e+01,\n \"amount\",\n 6,\n 7,\n 8\n]")), + sort = false) checkAnswer( array2.select(array_prepend(array_prepend(col("arr1"), col("d")), col("e"))), Seq( Row("[\n \"e1\",\n 2,\n 1,\n 2,\n 3\n]"), - Row("[\n \"e2\",\n 1,\n 6,\n 7,\n 8\n]") - ), - sort = false - ) + Row("[\n \"e2\",\n 1,\n 6,\n 7,\n 8\n]")), + sort = false) } test("array_size") { @@ -1739,32 +1521,28 @@ trait FunctionSuite extends TestData { checkAnswer( 
array3.select(array_slice(col("arr1"), col("d"), col("e"))), Seq(Row("[\n 2\n]"), Row("[\n 5\n]"), Row("[\n 6,\n 7\n]")), - sort = false - ) + sort = false) } test("array_to_string") { checkAnswer( array3.select(array_to_string(col("arr1"), col("f"))), Seq(Row("1,2,3"), Row("4, 5, 6"), Row("6;7;8")), - sort = false - ) + sort = false) } test("objectagg") { checkAnswer( object1.select(objectagg(col("key"), col("value"))), Seq(Row("{\n \"age\": 21,\n \"zip\": 94401\n}")), - sort = false - ) + sort = false) } test("object_construct") { checkAnswer( object1.select(object_construct(col("key"), col("value"))), Seq(Row("{\n \"age\": 21\n}"), Row("{\n \"zip\": 94401\n}")), - sort = false - ) + sort = false) checkAnswer(object1.select(object_construct()), Seq(Row("{}"), Row("{}")), sort = false) } @@ -1773,8 +1551,7 @@ trait FunctionSuite extends TestData { checkAnswer( object2.select(object_delete(col("obj"), col("k"), lit("name"), lit("non-exist-key"))), Seq(Row("{\n \"zip\": 21021\n}"), Row("{\n \"age\": 26,\n \"zip\": 94021\n}")), - sort = false - ) + sort = false) } test("object_insert") { @@ -1782,10 +1559,8 @@ trait FunctionSuite extends TestData { object2.select(object_insert(col("obj"), lit("key"), lit("v"))), Seq( Row("{\n \"age\": 21,\n \"key\": \"v\",\n \"name\": \"Joe\",\n \"zip\": 21021\n}"), - Row("{\n \"age\": 26,\n \"key\": \"v\",\n \"name\": \"Jay\",\n \"zip\": 94021\n}") - ), - sort = false - ) + Row("{\n \"age\": 26,\n \"key\": \"v\",\n \"name\": \"Jay\",\n \"zip\": 94021\n}")), + sort = false) // Get object result in Map[String, Variant] val resultSet = object2.select(object_insert(col("obj"), lit("key"), lit("v"))).collect() @@ -1793,84 +1568,71 @@ trait FunctionSuite extends TestData { "age" -> new Variant(21), "key" -> new Variant("v"), "name" -> new Variant("Joe"), - "zip" -> new Variant(21021) - ) + "zip" -> new Variant(21021)) assert(resultSet(0).getMapOfVariant(0).equals(row1)) val row2 = Map( "age" -> new Variant(26), "key" -> new Variant("v"), 
"name" -> new Variant("Jay"), - "zip" -> new Variant(94021) - ) + "zip" -> new Variant(94021)) assert(resultSet(1).getMapOfVariant(0).equals(row2)) checkAnswer( object2.select(object_insert(col("obj"), col("k"), col("v"), col("flag"))), Seq( Row("{\n \"age\": 0,\n \"name\": \"Joe\",\n \"zip\": 21021\n}"), - Row("{\n \"age\": 26,\n \"key\": 0,\n \"name\": \"Jay\",\n \"zip\": 94021\n}") - ), - sort = false - ) + Row("{\n \"age\": 26,\n \"key\": 0,\n \"name\": \"Jay\",\n \"zip\": 94021\n}")), + sort = false) } test("object_pick") { checkAnswer( object2.select(object_pick(col("obj"), col("k"), lit("name"), lit("non-exist-key"))), Seq(Row("{\n \"age\": 21,\n \"name\": \"Joe\"\n}"), Row("{\n \"name\": \"Jay\"\n}")), - sort = false - ) + sort = false) checkAnswer( object2.select(object_pick(col("obj"), array_construct(lit("name"), lit("zip")))), Seq( Row("{\n \"name\": \"Joe\",\n \"zip\": 21021\n}"), - Row("{\n \"name\": \"Jay\",\n \"zip\": 94021\n}") - ), - sort = false - ) + Row("{\n \"name\": \"Jay\",\n \"zip\": 94021\n}")), + sort = false) } test("as_array") { checkAnswer( array1.select(as_array(col("ARR1"))), Seq(Row("[\n 1,\n 2,\n 3\n]"), Row("[\n 6,\n 7,\n 8\n]")), - sort = false - ) + sort = false) checkAnswer( variant1.select(as_array(col("arr1")), as_array(col("bool1")), as_array(col("str1"))), Seq(Row("[\n \"Example\"\n]", null, null)), - sort = false - ) + sort = false) } test("as_binary") { checkAnswer( variant1.select(as_binary(col("bin1")), as_binary(col("bool1")), as_binary(col("str1"))), Seq(Row(Array[Byte](115, 110, 111, 119), null, null)), - sort = false - ) + sort = false) } test("as_char/as_varchar") { checkAnswer( variant1.select(as_char(col("str1")), as_char(col("bin1")), as_char(col("bool1"))), Seq(Row("X", null, null)), - sort = false - ) + sort = false) checkAnswer( variant1.select(as_varchar(col("str1")), as_varchar(col("bin1")), as_varchar(col("bool1"))), Seq(Row("X", null, null)), - sort = false - ) + sort = false) } test("as_date") { 
checkAnswer( variant1.select(as_date(col("date1")), as_date(col("time1")), as_date(col("bool1"))), Seq(Row(new Date(117, 1, 24), null, null)), - sort = false - ) + sort = false) } test("as_decimal/as_number") { @@ -1878,45 +1640,39 @@ trait FunctionSuite extends TestData { variant1 .select(as_decimal(col("decimal1")), as_decimal(col("double1")), as_decimal(col("num1"))), Seq(Row(1, null, 15)), - sort = false - ) + sort = false) assert( variant1 .select(as_decimal(col("decimal1"), 6)) .collect()(0) - .getLong(0) == 1 - ) + .getLong(0) == 1) assert( variant1 .select(as_decimal(col("decimal1"), 6, 3)) .collect()(0) .getDecimal(0) - .doubleValue() == 1.23 - ) + .doubleValue() == 1.23) checkAnswer( variant1 .select(as_number(col("decimal1")), as_number(col("double1")), as_number(col("num1"))), Seq(Row(1, null, 15)), - sort = false - ) + sort = false) assert( variant1 .select(as_number(col("decimal1"), 6)) .collect()(0) - .getLong(0) == 1 - ) + .getLong(0) == 1) assert( variant1 .select(as_number(col("decimal1"), 6, 3)) .collect()(0) .getDecimal(0) - .doubleValue() == 1.23 - ) + .doubleValue() == 1.23) } test("as_double/as_real") { @@ -1925,22 +1681,18 @@ trait FunctionSuite extends TestData { as_double(col("decimal1")), as_double(col("double1")), as_double(col("num1")), - as_double(col("bool1")) - ), + as_double(col("bool1"))), Seq(Row(1.23, 3.21, 15.0, null)), - sort = false - ) + sort = false) checkAnswer( variant1.select( as_real(col("decimal1")), as_real(col("double1")), as_real(col("num1")), - as_real(col("bool1")) - ), + as_real(col("bool1"))), Seq(Row(1.23, 3.21, 15.0, null)), - sort = false - ) + sort = false) } test("as_integer") { @@ -1949,19 +1701,16 @@ trait FunctionSuite extends TestData { as_integer(col("decimal1")), as_integer(col("double1")), as_integer(col("num1")), - as_integer(col("bool1")) - ), + as_integer(col("bool1"))), Seq(Row(1, null, 15, null)), - sort = false - ) + sort = false) } test("as_object") { checkAnswer( 
variant1.select(as_object(col("obj1")), as_object(col("arr1")), as_object(col("str1"))), Seq(Row("{\n \"Tree\": \"Pine\"\n}", null, null)), - sort = false - ) + sort = false) } test("as_time") { @@ -1969,8 +1718,7 @@ trait FunctionSuite extends TestData { variant1 .select(as_time(col("time1")), as_time(col("date1")), as_time(col("timestamp_tz1"))), Seq(Row(Time.valueOf("20:57:01"), null, null)), - sort = false - ) + sort = false) } test("as_timestamp_*") { @@ -1978,31 +1726,25 @@ trait FunctionSuite extends TestData { variant1.select( as_timestamp_ntz(col("timestamp_ntz1")), as_timestamp_ntz(col("timestamp_tz1")), - as_timestamp_ntz(col("timestamp_ltz1")) - ), + as_timestamp_ntz(col("timestamp_ltz1"))), Seq(Row(Timestamp.valueOf("2017-02-24 12:00:00.456"), null, null)), - sort = false - ) + sort = false) checkAnswer( variant1.select( as_timestamp_ltz(col("timestamp_ntz1")), as_timestamp_ltz(col("timestamp_tz1")), - as_timestamp_ltz(col("timestamp_ltz1")) - ), + as_timestamp_ltz(col("timestamp_ltz1"))), Seq(Row(null, null, Timestamp.valueOf("2017-02-24 04:00:00.123"))), - sort = false - ) + sort = false) checkAnswer( variant1.select( as_timestamp_tz(col("timestamp_ntz1")), as_timestamp_tz(col("timestamp_tz1")), - as_timestamp_tz(col("timestamp_ltz1")) - ), + as_timestamp_tz(col("timestamp_ltz1"))), Seq(Row(null, Timestamp.valueOf("2017-02-24 13:00:00.123"), null)), - sort = false - ) + sort = false) } test("strtok_to_array") { @@ -2011,20 +1753,16 @@ trait FunctionSuite extends TestData { .select(strtok_to_array(col("a"), col("b"))), Seq( Row("[\n \"1\",\n \"2\",\n \"3\",\n \"4\",\n \"5\"\n]"), - Row("[\n \"1\",\n \"2\",\n \"3\",\n \"4\",\n \"5\"\n]") - ), - sort = false - ) + Row("[\n \"1\",\n \"2\",\n \"3\",\n \"4\",\n \"5\"\n]")), + sort = false) checkAnswer( string6 .select(strtok_to_array(col("a"))), Seq( Row("[\n \"1,2,3,4,5\"\n]"), - Row("[\n \"1\",\n \"2\",\n \"3\",\n \"4\",\n \"5\"\n]") - ), - sort = false - ) + Row("[\n \"1\",\n \"2\",\n \"3\",\n \"4\",\n 
\"5\"\n]")), + sort = false) } test("to_array") { @@ -2032,8 +1770,7 @@ trait FunctionSuite extends TestData { integer1 .select(to_array(col("a"))), Seq(Row("[\n 1\n]"), Row("[\n 2\n]"), Row("[\n 3\n]")), - sort = false - ) + sort = false) } test("to_json") { @@ -2041,15 +1778,13 @@ trait FunctionSuite extends TestData { integer1 .select(to_json(col("a"))), Seq(Row("1"), Row("2"), Row("3")), - sort = false - ) + sort = false) checkAnswer( variant1 .select(to_json(col("time1"))), Seq(Row("\"20:57:01\"")), - sort = false - ) + sort = false) } test("to_object") { @@ -2057,8 +1792,7 @@ trait FunctionSuite extends TestData { variant1 .select(to_object(col("obj1"))), Seq(Row("{\n \"Tree\": \"Pine\"\n}")), - sort = false - ) + sort = false) } test("to_variant") { @@ -2066,8 +1800,7 @@ trait FunctionSuite extends TestData { integer1 .select(to_variant(col("a"))), Seq(Row("1"), Row("2"), Row("3")), - sort = false - ) + sort = false) assert(integer1.select(to_variant(col("a"))).collect()(0).getVariant(0).equals(new Variant(1))) } @@ -2078,10 +1811,8 @@ trait FunctionSuite extends TestData { Seq( Row("1"), Row("2"), - Row("3") - ), - sort = false - ) + Row("3")), + sort = false) } test("get") { @@ -2089,15 +1820,13 @@ trait FunctionSuite extends TestData { object2 .select(get(col("obj"), col("k"))), Seq(Row("21"), Row(null)), - sort = false - ) + sort = false) checkAnswer( object2 .select(get(col("obj"), lit("AGE"))), Seq(Row(null), Row(null)), - sort = false - ) + sort = false) } test("get_ignore_case") { @@ -2105,15 +1834,13 @@ trait FunctionSuite extends TestData { object2 .select(get(col("obj"), col("k"))), Seq(Row("21"), Row(null)), - sort = false - ) + sort = false) checkAnswer( object2 .select(get_ignore_case(col("obj"), lit("AGE"))), Seq(Row("21"), Row("26")), - sort = false - ) + sort = false) } test("object_keys") { @@ -2122,10 +1849,8 @@ trait FunctionSuite extends TestData { .select(object_keys(col("obj"))), Seq( Row("[\n \"age\",\n \"name\",\n \"zip\"\n]"), - 
Row("[\n \"age\",\n \"name\",\n \"zip\"\n]") - ), - sort = false - ) + Row("[\n \"age\",\n \"name\",\n \"zip\"\n]")), + sort = false) } test("xmlget") { @@ -2133,30 +1858,26 @@ trait FunctionSuite extends TestData { validXML1 .select(get(xmlget(col("v"), col("t2")), lit('$'))), Seq(Row("\"bar\""), Row(null), Row("\"foo\"")), - sort = false - ) + sort = false) assert( validXML1 .select(get(xmlget(col("v"), col("t2")), lit('$'))) .collect()(0) .getVariant(0) - .equals(new Variant("\"bar\"")) - ) + .equals(new Variant("\"bar\""))) checkAnswer( validXML1 .select(get(xmlget(col("v"), col("t3"), lit('0')), lit('@'))), Seq(Row("\"t3\""), Row(null), Row(null)), - sort = false - ) + sort = false) checkAnswer( validXML1 .select(get(xmlget(col("v"), col("t2"), col("instance")), lit('$'))), Seq(Row("\"bar\""), Row(null), Row("\"bar\"")), - sort = false - ) + sort = false) } test("get_path") { @@ -2164,8 +1885,7 @@ trait FunctionSuite extends TestData { validJson1 .select(get_path(col("v"), col("k"))), Seq(Row("null"), Row("\"foo\""), Row(null), Row(null)), - sort = false - ) + sort = false) } test("approx_percentile") { @@ -2184,19 +1904,15 @@ trait FunctionSuite extends TestData { "1.000000000000000e+00,\n 6.000000000000000e+00,\n 1.000000000000000e+00,\n " + "7.000000000000000e+00,\n 1.000000000000000e+00,\n 8.000000000000000e+00,\n " + "1.000000000000000e+00,\n 9.000000000000000e+00,\n 1.000000000000000e+00\n ]," + - "\n \"type\": \"tdigest\",\n \"version\": 1\n}" - ) - ), - sort = false - ) + "\n \"type\": \"tdigest\",\n \"version\": 1\n}")), + sort = false) } test("approx_percentile_estimate") { checkAnswer( approxNumbers.select(approx_percentile_estimate(approx_percentile_accumulate(col("a")), 0.5)), approxNumbers.select(approx_percentile(col("a"), 0.5)), - sort = false - ) + sort = false) } test("approx_percentile_combine") { @@ -2223,11 +1939,8 @@ trait FunctionSuite extends TestData { "8.000000000000000e+00,\n 1.000000000000000e+00,\n 8.000000000000000e+00,\n " + 
"1.000000000000000e+00,\n 9.000000000000000e+00,\n 1.000000000000000e+00,\n " + "9.000000000000000e+00,\n 1.000000000000000e+00\n ],\n \"type\": \"tdigest\"," + - "\n \"version\": 1\n}" - ) - ), - sort = false - ) + "\n \"version\": 1\n}")), + sort = false) } test("toScalar(DataFrame) with SELECT") { @@ -2255,8 +1968,7 @@ trait FunctionSuite extends TestData { expectedResult = Seq(Row(1 + 3, 3 - 1), Row(2 + 3, 3 - 1)) checkAnswer( testData1.select(col("num") + toScalar(dfMax), toScalar(dfMax) - toScalar(dfMin)), - expectedResult - ) + expectedResult) } test("col(DataFrame) with SELECT") { @@ -2304,19 +2016,16 @@ trait FunctionSuite extends TestData { expectedResult = Seq(Row(2, false, "b")) checkAnswer( testData1.filter(col("num") > col(dfMin) && col("num") < col(dfMax)), - expectedResult - ) + expectedResult) expectedResult = Seq(Row(1, true, "a")) checkAnswer( testData1.filter(col("num") >= col(dfMin) && col("num") < col(dfMax) - 1), - expectedResult - ) + expectedResult) // empty result assert( testData1 .filter(col("num") < col(dfMin) && col("num") > col(dfMax)) - .count() === 0 - ) + .count() === 0) } test("col(DataFrame) negative test") { @@ -2348,8 +2057,7 @@ trait FunctionSuite extends TestData { ||false |12 |14 |14 | ||true |22 |24 |22 | |-------------------------------------------------- - |""".stripMargin - ) + |""".stripMargin) val df2 = df.select(df("b"), df("c"), df("d"), iff(col("b") === col("c"), df("b"), df("d"))) assert( @@ -2361,8 +2069,7 @@ trait FunctionSuite extends TestData { ||12 |12 |14 |12 | ||22 |23 |24 |24 | |---------------------------------------------------------- - |""".stripMargin - ) + |""".stripMargin) } test("seq") { @@ -2376,38 +2083,32 @@ trait FunctionSuite extends TestData { session .generator(Byte.MaxValue.toInt + 10, Seq(seq1(false).as("a"))) .where(col("a") < 0) - .count() > 0 - ) + .count() > 0) assert( session .generator(Byte.MaxValue.toInt + 10, Seq(seq1().as("a"))) .where(col("a") < 0) - .count() == 0 - ) + .count() == 
0) assert( session .generator(Short.MaxValue.toInt + 10, Seq(seq2(false).as("a"))) .where(col("a") < 0) - .count() > 0 - ) + .count() > 0) assert( session .generator(Short.MaxValue.toInt + 10, Seq(seq2().as("a"))) .where(col("a") < 0) - .count() == 0 - ) + .count() == 0) assert( session .generator(Int.MaxValue.toLong + 10, Seq(seq4(false).as("a"))) .where(col("a") < 0) - .count() > 0 - ) + .count() > 0) assert( session .generator(Int.MaxValue.toLong + 10, Seq(seq4().as("a"))) .where(col("a") < 0) - .count() == 0 - ) + .count() == 0) // do not test the wrap-around of seq8, too costly. // test range @@ -2415,26 +2116,22 @@ trait FunctionSuite extends TestData { session .generator(Byte.MaxValue.toLong + 10, Seq(seq1(false).as("a"))) .where(col("a") > Byte.MaxValue) - .count() == 0 - ) + .count() == 0) assert( session .generator(Byte.MaxValue.toLong + 10, Seq(seq2(false).as("a"))) .where(col("a") > Byte.MaxValue) - .count() > 0 - ) + .count() > 0) assert( session .generator(Short.MaxValue.toLong + 10, Seq(seq4(false).as("a"))) .where(col("a") > Short.MaxValue) - .count() > 0 - ) + .count() > 0) assert( session .generator(Int.MaxValue.toLong + 10, Seq(seq8(false).as("a"))) .where(col("a") > Int.MaxValue) - .count() > 0 - ) + .count() > 0) } test("uniform") { @@ -2452,33 +2149,25 @@ trait FunctionSuite extends TestData { checkAnswer( df.select( listagg(df.col("col")) - .withinGroup(df.col("col").asc) - ), - Seq(Row("122345")) - ) + .withinGroup(df.col("col").asc)), + Seq(Row("122345"))) checkAnswer( df.select( listagg(df.col("col"), ",") - .withinGroup(df.col("col").asc) - ), - Seq(Row("1,2,2,3,4,5")) - ) + .withinGroup(df.col("col").asc)), + Seq(Row("1,2,2,3,4,5"))) checkAnswer( df.select( listagg(df.col("col"), ",", isDistinct = true) - .withinGroup(df.col("col").asc) - ), - Seq(Row("1,2,3,4,5")) - ) + .withinGroup(df.col("col").asc)), + Seq(Row("1,2,3,4,5"))) // delimiter is ' checkAnswer( df.select( listagg(df.col("col"), "'", isDistinct = true) - 
.withinGroup(df.col("col").asc) - ), - Seq(Row("1'2'3'4'5")) - ) + .withinGroup(df.col("col").asc)), + Seq(Row("1'2'3'4'5"))) } test("regexp_replace") { @@ -2492,8 +2181,7 @@ trait FunctionSuite extends TestData { checkAnswer( data.select(regexp_replace(data("a"), pattern, replacement)), expected, - sort = false - ) + sort = false) } test("regexp_extract") { val data = Seq("A MAN A PLAN A CANAL").toDF("a") @@ -2590,8 +2278,7 @@ trait FunctionSuite extends TestData { checkAnswer( input.select(date_format(col("date"), "YYYY/MM/DD").as("formatted_date")), expected, - sort = false - ) + sort = false) } test("last function") { @@ -2603,8 +2290,7 @@ trait FunctionSuite extends TestData { checkAnswer( input.select(last(col("name")).over(window).as("last_score_name")), expected, - sort = false - ) + sort = false) } test("log10 Column function") { diff --git a/src/test/scala/com/snowflake/snowpark_test/IndependentClassSuite.scala b/src/test/scala/com/snowflake/snowpark_test/IndependentClassSuite.scala index 4d1a5bff..7ffe6950 100644 --- a/src/test/scala/com/snowflake/snowpark_test/IndependentClassSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/IndependentClassSuite.scala @@ -11,17 +11,14 @@ class IndependentClassSuite extends FunSuite { test("scala variant") { checkDependencies( "target/classes/com/snowflake/snowpark/types/Variant.class", - Seq("com.snowflake.snowpark.types.Variant") - ) + Seq("com.snowflake.snowpark.types.Variant")) checkDependencies( "target/classes/com/snowflake/snowpark/types/Variant$.class", Seq( "com.snowflake.snowpark.types.Variant", "com.snowflake.snowpark.types.Geography", - "com.snowflake.snowpark.types.Geometry" - ) - ) + "com.snowflake.snowpark.types.Geometry")) } test("java variant") { @@ -29,47 +26,39 @@ class IndependentClassSuite extends FunSuite { "target/classes/com/snowflake/snowpark_java/types/Variant.class", Seq( "com.snowflake.snowpark_java.types.Variant", - "com.snowflake.snowpark_java.types.Geography" - ) - ) + 
"com.snowflake.snowpark_java.types.Geography")) } test("scala geography") { checkDependencies( "target/classes/com/snowflake/snowpark/types/Geography.class", - Seq("com.snowflake.snowpark.types.Geography") - ) + Seq("com.snowflake.snowpark.types.Geography")) checkDependencies( "target/classes/com/snowflake/snowpark/types/Geography$.class", - Seq("com.snowflake.snowpark.types.Geography") - ) + Seq("com.snowflake.snowpark.types.Geography")) } test("java geography") { checkDependencies( "target/classes/com/snowflake/snowpark_java/types/Geography.class", - Seq("com.snowflake.snowpark_java.types.Geography") - ) + Seq("com.snowflake.snowpark_java.types.Geography")) } test("scala geometry") { checkDependencies( "target/classes/com/snowflake/snowpark/types/Geometry.class", - Seq("com.snowflake.snowpark.types.Geometry") - ) + Seq("com.snowflake.snowpark.types.Geometry")) checkDependencies( "target/classes/com/snowflake/snowpark/types/Geometry$.class", - Seq("com.snowflake.snowpark.types.Geometry") - ) + Seq("com.snowflake.snowpark.types.Geometry")) } test("java geometry") { checkDependencies( "target/classes/com/snowflake/snowpark_java/types/Geometry.class", - Seq("com.snowflake.snowpark_java.types.Geometry") - ) + Seq("com.snowflake.snowpark_java.types.Geometry")) } // negative test, to make sure this test method works @@ -77,8 +66,7 @@ class IndependentClassSuite extends FunSuite { assertThrows[TestFailedException] { checkDependencies( "target/classes/com/snowflake/snowpark/Session.class", - Seq("com.snowflake.snowpark.Session") - ) + Seq("com.snowflake.snowpark.Session")) } } diff --git a/src/test/scala/com/snowflake/snowpark_test/JavaUtilsSuite.scala b/src/test/scala/com/snowflake/snowpark_test/JavaUtilsSuite.scala index 5e39008b..04228c55 100644 --- a/src/test/scala/com/snowflake/snowpark_test/JavaUtilsSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/JavaUtilsSuite.scala @@ -50,50 +50,40 @@ class JavaUtilsSuite extends FunSuite { test("variant array to 
string array") { assert( variantArrayToStringArray(Array(new com.snowflake.snowpark.types.Variant(1))) - .isInstanceOf[Array[String]] - ) + .isInstanceOf[Array[String]]) assert( variantArrayToStringArray(Array(new com.snowflake.snowpark_java.types.Variant(1))) - .isInstanceOf[Array[String]] - ) + .isInstanceOf[Array[String]]) } test("string array to variant array") { assert( stringArrayToVariantArray(Array("1")) - .isInstanceOf[Array[com.snowflake.snowpark.types.Variant]] - ) + .isInstanceOf[Array[com.snowflake.snowpark.types.Variant]]) assert( stringArrayToJavaVariantArray(Array("1")) - .isInstanceOf[Array[com.snowflake.snowpark_java.types.Variant]] - ) + .isInstanceOf[Array[com.snowflake.snowpark_java.types.Variant]]) } test("string map to variant map") { assert( stringMapToVariantMap(new util.HashMap[String, String]()) - .isInstanceOf[scala.collection.mutable.Map[String, com.snowflake.snowpark.types.Variant]] - ) + .isInstanceOf[scala.collection.mutable.Map[String, com.snowflake.snowpark.types.Variant]]) assert( stringMapToJavaVariantMap(new util.HashMap[String, String]()) - .isInstanceOf[util.Map[String, com.snowflake.snowpark_java.types.Variant]] - ) + .isInstanceOf[util.Map[String, com.snowflake.snowpark_java.types.Variant]]) } test("variant map to string map") { assert( variantMapToStringMap( - collection.mutable.Map.empty[String, com.snowflake.snowpark.types.Variant] - ) - .isInstanceOf[util.Map[String, String]] - ) + collection.mutable.Map.empty[String, com.snowflake.snowpark.types.Variant]) + .isInstanceOf[util.Map[String, String]]) assert( javaVariantMapToStringMap( - new util.HashMap[String, com.snowflake.snowpark_java.types.Variant]() - ) - .isInstanceOf[util.Map[String, String]] - ) + new util.HashMap[String, com.snowflake.snowpark_java.types.Variant]()) + .isInstanceOf[util.Map[String, String]]) } test("variant to string array") { @@ -101,8 +91,7 @@ class JavaUtilsSuite extends FunSuite { assert(variantToStringArray(new Variant(Array("a", 
"b"))).sameElements(Array("a", "b"))) assert( variantToStringArray(new Variant(Array(new Variant("a"), new Variant("b")))) - .sameElements(Array("a", "b")) - ) + .sameElements(Array("a", "b"))) } test("java variant to string array") { @@ -110,8 +99,7 @@ class JavaUtilsSuite extends FunSuite { assert(variantToStringArray(new JavaVariant(Array("a", "b"))).sameElements(Array("a", "b"))) assert( variantToStringArray(new JavaVariant(Array(new JavaVariant("a"), new JavaVariant("b")))) - .sameElements(Array("a", "b")) - ) + .sameElements(Array("a", "b"))) } test("variant to string map") { diff --git a/src/test/scala/com/snowflake/snowpark_test/LargeDataFrameSuite.scala b/src/test/scala/com/snowflake/snowpark_test/LargeDataFrameSuite.scala index 6a64f2af..c1afe2cb 100644 --- a/src/test/scala/com/snowflake/snowpark_test/LargeDataFrameSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/LargeDataFrameSuite.scala @@ -26,8 +26,7 @@ class LargeDataFrameSuite extends TestData { (col("id") + 6).as("g"), (col("id") + 7).as("h"), (col("id") + 8).as("i"), - (col("id") + 9).as("j") - ) + (col("id") + 9).as("j")) .cacheResult() val t1 = System.currentTimeMillis() df.collect() @@ -69,8 +68,7 @@ class LargeDataFrameSuite extends TestData { .collect() (0 until (result.length - 1)).foreach(index => - assert(result(index).getInt(0) < result(index + 1).getInt(0)) - ) + assert(result(index).getInt(0) < result(index + 1).getInt(0))) } test("createDataFrame for large values: basic types") { @@ -88,9 +86,7 @@ class LargeDataFrameSuite extends TestData { StructField("boolean", BooleanType), StructField("binary", BinaryType), StructField("timestamp", TimestampType), - StructField("date", DateType) - ) - ) + StructField("date", DateType))) val schemaString = """root | |--ID: Long (nullable = true) @@ -126,14 +122,11 @@ class LargeDataFrameSuite extends TestData { true, Array(1.toByte, 2.toByte), new Timestamp(timestamp - 100), - new Date(timestamp - 100) - ) - ) + new Date(timestamp - 
100))) } // Add one null values largeData.append( - Row(1025, null, null, null, null, null, null, null, null, null, null, null, null) - ) + Row(1025, null, null, null, null, null, null, null, null, null, null, null, null)) val result = session.createDataFrame(largeData, schema) // byte, short, int, long are converted to long @@ -159,8 +152,7 @@ class LargeDataFrameSuite extends TestData { """root | |--ID: Long (nullable = true) | |--TIME: Time (nullable = true) - |""".stripMargin - ) + |""".stripMargin) val expected = new ArrayBuffer[Row]() val snowflakeTime = session.sql("select '11:12:13' :: Time").collect()(0).getTime(0) @@ -180,9 +172,7 @@ class LargeDataFrameSuite extends TestData { StructField("map", MapType(null, null)), StructField("variant", VariantType), StructField("geography", GeographyType), - StructField("geometry", GeometryType) - ) - ) + StructField("geometry", GeometryType))) val rowCount = 350 val largeData = new ArrayBuffer[Row]() @@ -194,9 +184,7 @@ class LargeDataFrameSuite extends TestData { Map("'" -> 1), new Variant(1), Geography.fromGeoJSON("POINT(30 10)"), - Geometry.fromGeoJSON("POINT(20 40)") - ) - ) + Geometry.fromGeoJSON("POINT(20 40)"))) } largeData.append(Row(rowCount, null, null, null, null, null)) @@ -210,8 +198,7 @@ class LargeDataFrameSuite extends TestData { | |--VARIANT: Variant (nullable = true) | |--GEOGRAPHY: Geography (nullable = true) | |--GEOMETRY: Geometry (nullable = true) - |""".stripMargin - ) + |""".stripMargin) val expected = new ArrayBuffer[Row]() for (i <- 0 until rowCount) { @@ -234,9 +221,7 @@ class LargeDataFrameSuite extends TestData { | 4.000000000000000e+01 | ], | "type": "Point" - |}""".stripMargin) - ) - ) + |}""".stripMargin))) } expected.append(Row(rowCount, null, null, null, null, null)) checkAnswer(df.sort(col("id")), expected, sort = false) @@ -247,15 +232,12 @@ class LargeDataFrameSuite extends TestData { Seq( StructField("id", LongType), StructField("array", ArrayType(null)), - StructField("map", 
MapType(null, null)) - ) - ) + StructField("map", MapType(null, null)))) val largeData = new ArrayBuffer[Row]() val rowCount = 350 for (i <- 0 until rowCount) { largeData.append( - Row(i.toLong, Array(new Variant(1), new Variant("\"'")), Map("a" -> new Variant("\"'"))) - ) + Row(i.toLong, Array(new Variant(1), new Variant("\"'")), Map("a" -> new Variant("\"'")))) } largeData.append(Row(rowCount, null, null)) val df = session.createDataFrame(largeData, schema) @@ -272,9 +254,7 @@ class LargeDataFrameSuite extends TestData { Seq( StructField("id", LongType), StructField("array", ArrayType(null)), - StructField("map", MapType(null, null)) - ) - ) + StructField("map", MapType(null, null)))) val largeData = new ArrayBuffer[Row]() val rowCount = 350 for (i <- 0 until rowCount) { @@ -283,14 +263,10 @@ class LargeDataFrameSuite extends TestData { i.toLong, Array( Geography.fromGeoJSON("point(30 10)"), - Geography.fromGeoJSON("{\"type\":\"Point\",\"coordinates\":[30,10]}") - ), + Geography.fromGeoJSON("{\"type\":\"Point\",\"coordinates\":[30,10]}")), Map( "a" -> Geography.fromGeoJSON("point(30 10)"), - "b" -> Geography.fromGeoJSON("{\"type\":\"Point\",\"coordinates\":[300,100]}") - ) - ) - ) + "b" -> Geography.fromGeoJSON("{\"type\":\"Point\",\"coordinates\":[300,100]}")))) } largeData.append(Row(rowCount, null, null)) val df = session.createDataFrame(largeData, schema) @@ -302,9 +278,7 @@ class LargeDataFrameSuite extends TestData { "[\n \"point(30 10)\",\n {\n \"coordinates\": [\n" + " 30,\n 10\n ],\n \"type\": \"Point\"\n }\n]", "{\n \"a\": \"point(30 10)\",\n \"b\": {\n \"coordinates\": [\n" + - " 300,\n 100\n ],\n \"type\": \"Point\"\n }\n}" - ) - ) + " 300,\n 100\n ],\n \"type\": \"Point\"\n }\n}")) } expected.append(Row(rowCount, null, null)) checkAnswer(df.sort(col("id")), expected, sort = false) @@ -327,8 +301,7 @@ class LargeDataFrameSuite extends TestData { c.as("c6"), c.as("c7"), c.as("c8"), - c.as("c9") - ) + c.as("c9")) val rows = df.collect() 
assert(rows.length == 10000) assert(rows.last.getLong(0) == 9999) diff --git a/src/test/scala/com/snowflake/snowpark_test/LiteralSuite.scala b/src/test/scala/com/snowflake/snowpark_test/LiteralSuite.scala index 4aefd807..a7b07c2c 100644 --- a/src/test/scala/com/snowflake/snowpark_test/LiteralSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/LiteralSuite.scala @@ -37,8 +37,7 @@ class LiteralSuite extends TestData { | |--LONG: Long (nullable = false) | |--FLOAT: Double (nullable = false) | |--DOUBLE: Double (nullable = false) - |""".stripMargin - ) + |""".stripMargin) df.show() // scalastyle:off @@ -50,8 +49,7 @@ class LiteralSuite extends TestData { ||0 |NULL |string |C |true |10 |11 |12 |13 |14.0 |15.0 | ||1 |NULL |string |C |true |10 |11 |12 |13 |14.0 |15.0 | |----------------------------------------------------------------------------------------------------- - |""".stripMargin - ) + |""".stripMargin) // scalastyle:on } @@ -83,8 +81,7 @@ class LiteralSuite extends TestData { | |--LONG: Long (nullable = false) | |--FLOAT: Double (nullable = false) | |--DOUBLE: Double (nullable = false) - |""".stripMargin - ) + |""".stripMargin) // scalastyle:off assert( @@ -95,8 +92,7 @@ class LiteralSuite extends TestData { ||0 |NULL |string |C |true |10 |11 |12 |13 |14.0 |15.0 | ||1 |NULL |string |C |true |10 |11 |12 |13 |14.0 |15.0 | |----------------------------------------------------------------------------------------------------- - |""".stripMargin - ) + |""".stripMargin) // scalastyle:on } @@ -128,8 +124,7 @@ class LiteralSuite extends TestData { | |--LONG: Long (nullable = false) | |--FLOAT: Double (nullable = false) | |--DOUBLE: Double (nullable = false) - |""".stripMargin - ) + |""".stripMargin) // scalastyle:off assert( @@ -140,8 +135,7 @@ class LiteralSuite extends TestData { ||0 |NULL |string |C |true |10 |11 |12 |13 |14.0 |15.0 | ||1 |NULL |string |C |true |10 |11 |12 |13 |14.0 |15.0 | 
|----------------------------------------------------------------------------------------------------- - |""".stripMargin - ) + |""".stripMargin) // scalastyle:on } @@ -157,8 +151,7 @@ class LiteralSuite extends TestData { | |--ID: Long (nullable = false) | |--SCALA: Binary (nullable = false) | |--JAVA: Binary (nullable = false) - |""".stripMargin - ) + |""".stripMargin) assert( getShowString(df, 10) == @@ -168,8 +161,7 @@ class LiteralSuite extends TestData { ||0 |'616263' |'656667' | ||1 |'616263' |'656667' | |------------------------------ - |""".stripMargin - ) + |""".stripMargin) } test("Literal TimeStamp and Instant") { @@ -190,8 +182,7 @@ class LiteralSuite extends TestData { | |--ID: Long (nullable = false) | |--TIMESTAMP: Timestamp (nullable = false) | |--INSTANT: Timestamp (nullable = false) - |""".stripMargin - ) + |""".stripMargin) assert( getShowString(df, 10) == @@ -201,8 +192,7 @@ class LiteralSuite extends TestData { ||0 |2018-10-11 12:13:14.123 |2020-10-11 12:13:14.123 | ||1 |2018-10-11 12:13:14.123 |2020-10-11 12:13:14.123 | |------------------------------------------------------------ - |""".stripMargin - ) + |""".stripMargin) } finally { TimeZone.setDefault(oldTimeZone) } @@ -223,8 +213,7 @@ class LiteralSuite extends TestData { | |--ID: Long (nullable = false) | |--LOCAL_DATE: Date (nullable = false) | |--DATE: Date (nullable = false) - |""".stripMargin - ) + |""".stripMargin) assert( getShowString(df, 10) == @@ -234,8 +223,7 @@ class LiteralSuite extends TestData { ||0 |2020-10-11 |2018-10-11 | ||1 |2020-10-11 |2018-10-11 | |------------------------------------ - |""".stripMargin - ) + |""".stripMargin) } finally { TimeZone.setDefault(oldTimeZone) } @@ -255,8 +243,7 @@ class LiteralSuite extends TestData { | |--ID: Long (nullable = false) | |--NULL: String (nullable = true) | |--LITERAL: Long (nullable = false) - |""".stripMargin - ) + |""".stripMargin) assert( getShowString(df, 10) == @@ -266,7 +253,6 @@ class LiteralSuite extends TestData { 
||0 |NULL |123 | ||1 |NULL |123 | |----------------------------- - |""".stripMargin - ) + |""".stripMargin) } } diff --git a/src/test/scala/com/snowflake/snowpark_test/OpenTelemetrySuite.scala b/src/test/scala/com/snowflake/snowpark_test/OpenTelemetrySuite.scala index 693e5a35..90bfe4d9 100644 --- a/src/test/scala/com/snowflake/snowpark_test/OpenTelemetrySuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/OpenTelemetrySuite.scala @@ -261,8 +261,10 @@ class OpenTelemetrySuite extends OpenTelemetryEnabled { val stageName = randomName() val tableName = randomName() val userSchema: StructType = StructType( - Seq(StructField("a", IntegerType), StructField("b", StringType), StructField("c", DoubleType)) - ) + Seq( + StructField("a", IntegerType), + StructField("b", StringType), + StructField("c", DoubleType))) try { createStage(stageName) uploadFileToStage(stageName, testFileCsv, compress = false) @@ -291,8 +293,10 @@ class OpenTelemetrySuite extends OpenTelemetryEnabled { val tableName = randomName() val className = "snow.snowpark.CopyableDataFrameAsyncActor" val userSchema: StructType = StructType( - Seq(StructField("a", IntegerType), StructField("b", StringType), StructField("c", DoubleType)) - ) + Seq( + StructField("a", IntegerType), + StructField("b", StringType), + StructField("c", DoubleType))) try { createStage(stageName) uploadFileToStage(stageName, testFileCsv, compress = false) @@ -462,7 +466,6 @@ class OpenTelemetrySuite extends OpenTelemetryEnabled { funcName, "OpenTelemetrySuite.scala", file.getLineNumber - 1, - s"DataFrame.$funcName" - ) + s"DataFrame.$funcName") } } diff --git a/src/test/scala/com/snowflake/snowpark_test/PermanentUDFSuite.scala b/src/test/scala/com/snowflake/snowpark_test/PermanentUDFSuite.scala index 83d61201..116bb416 100644 --- a/src/test/scala/com/snowflake/snowpark_test/PermanentUDFSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/PermanentUDFSuite.scala @@ -200,8 +200,7 @@ class PermanentUDFSuite extends 
TestData { | return x; | } |}'""".stripMargin, - session - ) + session) // Before the UDF registration, there is no file in @$stageName1/$permFuncName/ assert(session.sql(s"ls @$stageName1/$permFuncName/").collect().length == 0) // The same name UDF registration will fail. @@ -282,13 +281,11 @@ class PermanentUDFSuite extends TestData { checkAnswer( df.select(callUDF(funcName, df("a1"), df("a2"))), Seq(Row(3), Row(23)), - sort = false - ) + sort = false) checkAnswer( newDf.select(callUDF(funcName, newDf("a1"), newDf("a2"))), Seq(Row(3), Row(23)), - sort = false - ) + sort = false) } finally { runQuery(s"drop function $funcName(INT,INT)", session) } @@ -308,13 +305,11 @@ class PermanentUDFSuite extends TestData { checkAnswer( df.select(callUDF(funcName, df("a1"), df("a2"), df("a3"))), Seq(Row(6), Row(36)), - sort = false - ) + sort = false) checkAnswer( newDf.select(callUDF(funcName, newDf("a1"), newDf("a2"), newDf("a3"))), Seq(Row(6), Row(36)), - sort = false - ) + sort = false) } finally { runQuery(s"drop function $funcName(INT,INT,INT)", session) } @@ -337,13 +332,11 @@ class PermanentUDFSuite extends TestData { checkAnswer( df.select(callUDF(funcName, df("a1"), df("a2"), df("a3"), df("a4"))), Seq(Row(10), Row(50)), - sort = false - ) + sort = false) checkAnswer( newDf.select(callUDF(funcName, newDf("a1"), newDf("a2"), newDf("a3"), newDf("a4"))), Seq(Row(10), Row(50)), - sort = false - ) + sort = false) } finally { runQuery(s"drop function $funcName(INT,INT,INT,INT)", session) } @@ -366,15 +359,12 @@ class PermanentUDFSuite extends TestData { checkAnswer( df.select(callUDF(funcName, df("a1"), df("a2"), df("a3"), df("a4"), df("a5"))), Seq(Row(15), Row(65)), - sort = false - ) + sort = false) checkAnswer( newDf.select( - callUDF(funcName, newDf("a1"), newDf("a2"), newDf("a3"), newDf("a4"), newDf("a5")) - ), + callUDF(funcName, newDf("a1"), newDf("a2"), newDf("a3"), newDf("a4"), newDf("a5"))), Seq(Row(15), Row(65)), - sort = false - ) + sort = false) } finally { 
runQuery(s"drop function $funcName(INT,INT,INT,INT,INT)", session) } @@ -398,8 +388,7 @@ class PermanentUDFSuite extends TestData { checkAnswer( df.select(callUDF(funcName, df("a1"), df("a2"), df("a3"), df("a4"), df("a5"), df("a6"))), Seq(Row(21), Row(81)), - sort = false - ) + sort = false) checkAnswer( newDf.select( callUDF( @@ -409,12 +398,9 @@ class PermanentUDFSuite extends TestData { newDf("a3"), newDf("a4"), newDf("a5"), - newDf("a6") - ) - ), + newDf("a6"))), Seq(Row(21), Row(81)), - sort = false - ) + sort = false) } finally { runQuery(s"drop function $funcName(INT,INT,INT,INT,INT,INT)", session) } @@ -437,11 +423,9 @@ class PermanentUDFSuite extends TestData { session.udf.registerPermanent(funcName, func, stageName) checkAnswer( df.select( - callUDF(funcName, df("a1"), df("a2"), df("a3"), df("a4"), df("a5"), df("a6"), df("a7")) - ), + callUDF(funcName, df("a1"), df("a2"), df("a3"), df("a4"), df("a5"), df("a6"), df("a7"))), Seq(Row(28), Row(98)), - sort = false - ) + sort = false) checkAnswer( newDf.select( callUDF( @@ -452,12 +436,9 @@ class PermanentUDFSuite extends TestData { newDf("a4"), newDf("a5"), newDf("a6"), - newDf("a7") - ) - ), + newDf("a7"))), Seq(Row(28), Row(98)), - sort = false - ) + sort = false) } finally { runQuery(s"drop function $funcName(INT,INT,INT,INT,INT,INT,INT)", session) } @@ -489,12 +470,9 @@ class PermanentUDFSuite extends TestData { df("a5"), df("a6"), df("a7"), - df("a8") - ) - ), + df("a8"))), Seq(Row(36), Row(116)), - sort = false - ) + sort = false) checkAnswer( newDf.select( callUDF( @@ -506,12 +484,9 @@ class PermanentUDFSuite extends TestData { newDf("a5"), newDf("a6"), newDf("a7"), - newDf("a8") - ) - ), + newDf("a8"))), Seq(Row(36), Row(116)), - sort = false - ) + sort = false) } finally { runQuery(s"drop function $funcName(INT,INT,INT,INT,INT,INT,INT,INT)", session) } @@ -545,12 +520,9 @@ class PermanentUDFSuite extends TestData { df("a6"), df("a7"), df("a8"), - df("a9") - ) - ), + df("a9"))), Seq(Row(45), 
Row(135)), - sort = false - ) + sort = false) checkAnswer( newDf.select( callUDF( @@ -563,12 +535,9 @@ class PermanentUDFSuite extends TestData { newDf("a6"), newDf("a7"), newDf("a8"), - newDf("a9") - ) - ), + newDf("a9"))), Seq(Row(45), Row(135)), - sort = false - ) + sort = false) } finally { runQuery(s"drop function $funcName(INT,INT,INT,INT,INT,INT,INT,INT,INT)", session) } @@ -580,8 +549,7 @@ class PermanentUDFSuite extends TestData { import functions.callUDF val df = session .createDataFrame( - Seq((1, 2, 3, 4, 5, 6, 7, 8, 9, 10), (11, 12, 13, 14, 15, 16, 17, 18, 19, 20)) - ) + Seq((1, 2, 3, 4, 5, 6, 7, 8, 9, 10), (11, 12, 13, 14, 15, 16, 17, 18, 19, 20))) .toDF(Seq("a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9", "a10")) val func = (a1: Int, a2: Int, a3: Int, a4: Int, a5: Int, a6: Int, a7: Int, a8: Int, a9: Int, a10: Int) => @@ -589,8 +557,7 @@ class PermanentUDFSuite extends TestData { val funcName: String = randomName() val newDf = newSession .createDataFrame( - Seq((1, 2, 3, 4, 5, 6, 7, 8, 9, 10), (11, 12, 13, 14, 15, 16, 17, 18, 19, 20)) - ) + Seq((1, 2, 3, 4, 5, 6, 7, 8, 9, 10), (11, 12, 13, 14, 15, 16, 17, 18, 19, 20))) .toDF(Seq("a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9", "a10")) try { session.udf.registerPermanent(funcName, func, stageName) @@ -607,12 +574,9 @@ class PermanentUDFSuite extends TestData { df("a7"), df("a8"), df("a9"), - df("a10") - ) - ), + df("a10"))), Seq(Row(55), Row(155)), - sort = false - ) + sort = false) checkAnswer( newDf.select( callUDF( @@ -626,12 +590,9 @@ class PermanentUDFSuite extends TestData { newDf("a7"), newDf("a8"), newDf("a9"), - newDf("a10") - ) - ), + newDf("a10"))), Seq(Row(55), Row(155)), - sort = false - ) + sort = false) } finally { runQuery(s"drop function $funcName(INT,INT,INT,INT,INT,INT,INT,INT,INT,INT)", session) } @@ -643,8 +604,7 @@ class PermanentUDFSuite extends TestData { import functions.callUDF val df = session .createDataFrame( - Seq((1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11), (11, 12, 13, 
14, 15, 16, 17, 18, 19, 20, 21)) - ) + Seq((1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11), (11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21))) .toDF(Seq("a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9", "a10", "a11")) val func = ( a1: Int, @@ -657,13 +617,11 @@ class PermanentUDFSuite extends TestData { a8: Int, a9: Int, a10: Int, - a11: Int - ) => a1 + a2 + a3 + a4 + a5 + a6 + a7 + a8 + a9 + a10 + a11 + a11: Int) => a1 + a2 + a3 + a4 + a5 + a6 + a7 + a8 + a9 + a10 + a11 val funcName: String = randomName() val newDf = newSession .createDataFrame( - Seq((1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11), (11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)) - ) + Seq((1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11), (11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21))) .toDF(Seq("a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9", "a10", "a11")) try { session.udf.registerPermanent(funcName, func, stageName) @@ -681,12 +639,9 @@ class PermanentUDFSuite extends TestData { df("a8"), df("a9"), df("a10"), - df("a11") - ) - ), + df("a11"))), Seq(Row(66), Row(176)), - sort = false - ) + sort = false) checkAnswer( newDf.select( callUDF( @@ -701,12 +656,9 @@ class PermanentUDFSuite extends TestData { newDf("a8"), newDf("a9"), newDf("a10"), - newDf("a11") - ) - ), + newDf("a11"))), Seq(Row(66), Row(176)), - sort = false - ) + sort = false) } finally { runQuery(s"drop function $funcName(INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT)", session) } @@ -720,9 +672,7 @@ class PermanentUDFSuite extends TestData { .createDataFrame( Seq( (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12), - (11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22) - ) - ) + (11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22))) .toDF(Seq("a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9", "a10", "a11", "a12")) val func = ( a1: Int, @@ -736,16 +686,13 @@ class PermanentUDFSuite extends TestData { a9: Int, a10: Int, a11: Int, - a12: Int - ) => a1 + a2 + a3 + a4 + a5 + a6 + a7 + a8 + a9 + a10 + a11 + a12 + a12: Int) => a1 + a2 + a3 + a4 + a5 + a6 + a7 + a8 + a9 + a10 + a11 + a12 
val funcName: String = randomName() val newDf = newSession .createDataFrame( Seq( (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12), - (11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22) - ) - ) + (11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22))) .toDF(Seq("a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9", "a10", "a11", "a12")) try { session.udf.registerPermanent(funcName, func, stageName) @@ -764,12 +711,9 @@ class PermanentUDFSuite extends TestData { df("a9"), df("a10"), df("a11"), - df("a12") - ) - ), + df("a12"))), Seq(Row(78), Row(198)), - sort = false - ) + sort = false) checkAnswer( newDf.select( callUDF( @@ -785,12 +729,9 @@ class PermanentUDFSuite extends TestData { newDf("a9"), newDf("a10"), newDf("a11"), - newDf("a12") - ) - ), + newDf("a12"))), Seq(Row(78), Row(198)), - sort = false - ) + sort = false) } finally { runQuery(s"drop function $funcName(INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT)", session) } @@ -804,9 +745,7 @@ class PermanentUDFSuite extends TestData { .createDataFrame( Seq( (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13), - (11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23) - ) - ) + (11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23))) .toDF(Seq("a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9", "a10", "a11", "a12", "a13")) val func = ( a1: Int, @@ -821,16 +760,13 @@ class PermanentUDFSuite extends TestData { a10: Int, a11: Int, a12: Int, - a13: Int - ) => a1 + a2 + a3 + a4 + a5 + a6 + a7 + a8 + a9 + a10 + a11 + a12 + a13 + a13: Int) => a1 + a2 + a3 + a4 + a5 + a6 + a7 + a8 + a9 + a10 + a11 + a12 + a13 val funcName: String = randomName() val newDf = newSession .createDataFrame( Seq( (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13), - (11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23) - ) - ) + (11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23))) .toDF(Seq("a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9", "a10", "a11", "a12", "a13")) try { session.udf.registerPermanent(funcName, func, stageName) @@ -850,12 +786,9 @@ class 
PermanentUDFSuite extends TestData { df("a10"), df("a11"), df("a12"), - df("a13") - ) - ), + df("a13"))), Seq(Row(91), Row(221)), - sort = false - ) + sort = false) checkAnswer( newDf.select( callUDF( @@ -872,17 +805,13 @@ class PermanentUDFSuite extends TestData { newDf("a10"), newDf("a11"), newDf("a12"), - newDf("a13") - ) - ), + newDf("a13"))), Seq(Row(91), Row(221)), - sort = false - ) + sort = false) } finally { runQuery( s"drop function $funcName(INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT)", - session - ) + session) } // scalastyle:on } @@ -894,12 +823,23 @@ class PermanentUDFSuite extends TestData { .createDataFrame( Seq( (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14), - (11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24) - ) - ) + (11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24))) .toDF( - Seq("a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9", "a10", "a11", "a12", "a13", "a14") - ) + Seq( + "a1", + "a2", + "a3", + "a4", + "a5", + "a6", + "a7", + "a8", + "a9", + "a10", + "a11", + "a12", + "a13", + "a14")) val func = ( a1: Int, a2: Int, @@ -914,19 +854,29 @@ class PermanentUDFSuite extends TestData { a11: Int, a12: Int, a13: Int, - a14: Int - ) => a1 + a2 + a3 + a4 + a5 + a6 + a7 + a8 + a9 + a10 + a11 + a12 + a13 + a14 + a14: Int) => a1 + a2 + a3 + a4 + a5 + a6 + a7 + a8 + a9 + a10 + a11 + a12 + a13 + a14 val funcName: String = randomName() val newDf = newSession .createDataFrame( Seq( (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14), - (11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24) - ) - ) + (11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24))) .toDF( - Seq("a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9", "a10", "a11", "a12", "a13", "a14") - ) + Seq( + "a1", + "a2", + "a3", + "a4", + "a5", + "a6", + "a7", + "a8", + "a9", + "a10", + "a11", + "a12", + "a13", + "a14")) try { session.udf.registerPermanent(funcName, func, stageName) checkAnswer( @@ -946,12 +896,9 @@ class PermanentUDFSuite extends TestData { df("a11"), 
df("a12"), df("a13"), - df("a14") - ) - ), + df("a14"))), Seq(Row(105), Row(245)), - sort = false - ) + sort = false) checkAnswer( newDf.select( callUDF( @@ -969,17 +916,13 @@ class PermanentUDFSuite extends TestData { newDf("a11"), newDf("a12"), newDf("a13"), - newDf("a14") - ) - ), + newDf("a14"))), Seq(Row(105), Row(245)), - sort = false - ) + sort = false) } finally { runQuery( s"drop function $funcName(INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT)", - session - ) + session) } // scalastyle:on } @@ -991,9 +934,7 @@ class PermanentUDFSuite extends TestData { .createDataFrame( Seq( (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15), - (11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25) - ) - ) + (11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25))) .toDF( Seq( "a1", @@ -1010,9 +951,7 @@ class PermanentUDFSuite extends TestData { "a12", "a13", "a14", - "a15" - ) - ) + "a15")) val func = ( a1: Int, a2: Int, @@ -1028,16 +967,13 @@ class PermanentUDFSuite extends TestData { a12: Int, a13: Int, a14: Int, - a15: Int - ) => a1 + a2 + a3 + a4 + a5 + a6 + a7 + a8 + a9 + a10 + a11 + a12 + a13 + a14 + a15 + a15: Int) => a1 + a2 + a3 + a4 + a5 + a6 + a7 + a8 + a9 + a10 + a11 + a12 + a13 + a14 + a15 val funcName: String = randomName() val newDf = newSession .createDataFrame( Seq( (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15), - (11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25) - ) - ) + (11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25))) .toDF( Seq( "a1", @@ -1054,9 +990,7 @@ class PermanentUDFSuite extends TestData { "a12", "a13", "a14", - "a15" - ) - ) + "a15")) try { session.udf.registerPermanent(funcName, func, stageName) checkAnswer( @@ -1077,12 +1011,9 @@ class PermanentUDFSuite extends TestData { df("a12"), df("a13"), df("a14"), - df("a15") - ) - ), + df("a15"))), Seq(Row(120), Row(270)), - sort = false - ) + sort = false) checkAnswer( newDf.select( callUDF( @@ -1101,17 +1032,13 @@ class PermanentUDFSuite 
extends TestData { newDf("a12"), newDf("a13"), newDf("a14"), - newDf("a15") - ) - ), + newDf("a15"))), Seq(Row(120), Row(270)), - sort = false - ) + sort = false) } finally { runQuery( s"drop function $funcName(INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT)", - session - ) + session) } // scalastyle:on } @@ -1123,9 +1050,7 @@ class PermanentUDFSuite extends TestData { .createDataFrame( Seq( (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16), - (11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26) - ) - ) + (11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26))) .toDF( Seq( "a1", @@ -1143,9 +1068,7 @@ class PermanentUDFSuite extends TestData { "a13", "a14", "a15", - "a16" - ) - ) + "a16")) val func = ( a1: Int, a2: Int, @@ -1162,16 +1085,14 @@ class PermanentUDFSuite extends TestData { a13: Int, a14: Int, a15: Int, - a16: Int - ) => a1 + a2 + a3 + a4 + a5 + a6 + a7 + a8 + a9 + a10 + a11 + a12 + a13 + a14 + a15 + a16 + a16: Int) => + a1 + a2 + a3 + a4 + a5 + a6 + a7 + a8 + a9 + a10 + a11 + a12 + a13 + a14 + a15 + a16 val funcName: String = randomName() val newDf = newSession .createDataFrame( Seq( (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16), - (11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26) - ) - ) + (11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26))) .toDF( Seq( "a1", @@ -1189,9 +1110,7 @@ class PermanentUDFSuite extends TestData { "a13", "a14", "a15", - "a16" - ) - ) + "a16")) try { session.udf.registerPermanent(funcName, func, stageName) checkAnswer( @@ -1213,12 +1132,9 @@ class PermanentUDFSuite extends TestData { df("a13"), df("a14"), df("a15"), - df("a16") - ) - ), + df("a16"))), Seq(Row(136), Row(296)), - sort = false - ) + sort = false) checkAnswer( newDf.select( callUDF( @@ -1238,17 +1154,13 @@ class PermanentUDFSuite extends TestData { newDf("a13"), newDf("a14"), newDf("a15"), - newDf("a16") - ) - ), + newDf("a16"))), Seq(Row(136), Row(296)), - sort = false - ) + sort = 
false) } finally { runQuery( s"drop function $funcName(INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT)", - session - ) + session) } // scalastyle:on } @@ -1260,9 +1172,7 @@ class PermanentUDFSuite extends TestData { .createDataFrame( Seq( (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17), - (11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27) - ) - ) + (11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27))) .toDF( Seq( "a1", @@ -1281,9 +1191,7 @@ class PermanentUDFSuite extends TestData { "a14", "a15", "a16", - "a17" - ) - ) + "a17")) val func = ( a1: Int, a2: Int, @@ -1301,16 +1209,14 @@ class PermanentUDFSuite extends TestData { a14: Int, a15: Int, a16: Int, - a17: Int - ) => a1 + a2 + a3 + a4 + a5 + a6 + a7 + a8 + a9 + a10 + a11 + a12 + a13 + a14 + a15 + a16 + a17 + a17: Int) => + a1 + a2 + a3 + a4 + a5 + a6 + a7 + a8 + a9 + a10 + a11 + a12 + a13 + a14 + a15 + a16 + a17 val funcName: String = randomName() val newDf = newSession .createDataFrame( Seq( (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17), - (11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27) - ) - ) + (11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27))) .toDF( Seq( "a1", @@ -1329,9 +1235,7 @@ class PermanentUDFSuite extends TestData { "a14", "a15", "a16", - "a17" - ) - ) + "a17")) try { session.udf.registerPermanent(funcName, func, stageName) checkAnswer( @@ -1354,12 +1258,9 @@ class PermanentUDFSuite extends TestData { df("a14"), df("a15"), df("a16"), - df("a17") - ) - ), + df("a17"))), Seq(Row(153), Row(323)), - sort = false - ) + sort = false) checkAnswer( newDf.select( callUDF( @@ -1380,17 +1281,13 @@ class PermanentUDFSuite extends TestData { newDf("a14"), newDf("a15"), newDf("a16"), - newDf("a17") - ) - ), + newDf("a17"))), Seq(Row(153), Row(323)), - sort = false - ) + sort = false) } finally { runQuery( s"drop function 
$funcName(INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT)", - session - ) + session) } // scalastyle:on } @@ -1402,9 +1299,7 @@ class PermanentUDFSuite extends TestData { .createDataFrame( Seq( (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18), - (11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28) - ) - ) + (11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28))) .toDF( Seq( "a1", @@ -1424,9 +1319,7 @@ class PermanentUDFSuite extends TestData { "a15", "a16", "a17", - "a18" - ) - ) + "a18")) val func = ( a1: Int, a2: Int, @@ -1445,17 +1338,14 @@ class PermanentUDFSuite extends TestData { a15: Int, a16: Int, a17: Int, - a18: Int - ) => + a18: Int) => a1 + a2 + a3 + a4 + a5 + a6 + a7 + a8 + a9 + a10 + a11 + a12 + a13 + a14 + a15 + a16 + a17 + a18 val funcName: String = randomName() val newDf = newSession .createDataFrame( Seq( (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18), - (11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28) - ) - ) + (11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28))) .toDF( Seq( "a1", @@ -1475,9 +1365,7 @@ class PermanentUDFSuite extends TestData { "a15", "a16", "a17", - "a18" - ) - ) + "a18")) try { session.udf.registerPermanent(funcName, func, stageName) checkAnswer( @@ -1501,12 +1389,9 @@ class PermanentUDFSuite extends TestData { df("a15"), df("a16"), df("a17"), - df("a18") - ) - ), + df("a18"))), Seq(Row(171), Row(351)), - sort = false - ) + sort = false) checkAnswer( newDf.select( callUDF( @@ -1528,17 +1413,13 @@ class PermanentUDFSuite extends TestData { newDf("a15"), newDf("a16"), newDf("a17"), - newDf("a18") - ) - ), + newDf("a18"))), Seq(Row(171), Row(351)), - sort = false - ) + sort = false) } finally { runQuery( s"drop function $funcName(INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT)", - session - ) + session) } // scalastyle:on } @@ -1550,9 +1431,7 @@ class PermanentUDFSuite 
extends TestData { .createDataFrame( Seq( (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19), - (11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29) - ) - ) + (11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29))) .toDF( Seq( "a1", @@ -1573,9 +1452,7 @@ class PermanentUDFSuite extends TestData { "a16", "a17", "a18", - "a19" - ) - ) + "a19")) val func = ( a1: Int, a2: Int, @@ -1595,17 +1472,14 @@ class PermanentUDFSuite extends TestData { a16: Int, a17: Int, a18: Int, - a19: Int - ) => + a19: Int) => a1 + a2 + a3 + a4 + a5 + a6 + a7 + a8 + a9 + a10 + a11 + a12 + a13 + a14 + a15 + a16 + a17 + a18 + a19 val funcName: String = randomName() val newDf = newSession .createDataFrame( Seq( (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19), - (11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29) - ) - ) + (11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29))) .toDF( Seq( "a1", @@ -1626,9 +1500,7 @@ class PermanentUDFSuite extends TestData { "a16", "a17", "a18", - "a19" - ) - ) + "a19")) try { session.udf.registerPermanent(funcName, func, stageName) checkAnswer( @@ -1653,12 +1525,9 @@ class PermanentUDFSuite extends TestData { df("a16"), df("a17"), df("a18"), - df("a19") - ) - ), + df("a19"))), Seq(Row(190), Row(380)), - sort = false - ) + sort = false) checkAnswer( newDf.select( callUDF( @@ -1681,17 +1550,13 @@ class PermanentUDFSuite extends TestData { newDf("a16"), newDf("a17"), newDf("a18"), - newDf("a19") - ) - ), + newDf("a19"))), Seq(Row(190), Row(380)), - sort = false - ) + sort = false) } finally { runQuery( s"drop function $funcName(INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT)", - session - ) + session) } // scalastyle:on } @@ -1703,9 +1568,7 @@ class PermanentUDFSuite extends TestData { .createDataFrame( Seq( (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20), - (11, 12, 13, 14, 15, 
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30) - ) - ) + (11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30))) .toDF( Seq( "a1", @@ -1727,9 +1590,7 @@ class PermanentUDFSuite extends TestData { "a17", "a18", "a19", - "a20" - ) - ) + "a20")) val func = ( a1: Int, a2: Int, @@ -1750,17 +1611,14 @@ class PermanentUDFSuite extends TestData { a17: Int, a18: Int, a19: Int, - a20: Int - ) => + a20: Int) => a1 + a2 + a3 + a4 + a5 + a6 + a7 + a8 + a9 + a10 + a11 + a12 + a13 + a14 + a15 + a16 + a17 + a18 + a19 + a20 val funcName: String = randomName() val newDf = newSession .createDataFrame( Seq( (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20), - (11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30) - ) - ) + (11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30))) .toDF( Seq( "a1", @@ -1782,9 +1640,7 @@ class PermanentUDFSuite extends TestData { "a17", "a18", "a19", - "a20" - ) - ) + "a20")) try { session.udf.registerPermanent(funcName, func, stageName) checkAnswer( @@ -1810,12 +1666,9 @@ class PermanentUDFSuite extends TestData { df("a17"), df("a18"), df("a19"), - df("a20") - ) - ), + df("a20"))), Seq(Row(210), Row(410)), - sort = false - ) + sort = false) checkAnswer( newDf.select( callUDF( @@ -1839,17 +1692,13 @@ class PermanentUDFSuite extends TestData { newDf("a17"), newDf("a18"), newDf("a19"), - newDf("a20") - ) - ), + newDf("a20"))), Seq(Row(210), Row(410)), - sort = false - ) + sort = false) } finally { runQuery( s"drop function $funcName(INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT)", - session - ) + session) } // scalastyle:on } @@ -1861,9 +1710,7 @@ class PermanentUDFSuite extends TestData { .createDataFrame( Seq( (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21), - (11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31) - ) - ) + (11, 12, 13, 14, 15, 16, 17, 
18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31))) .toDF( Seq( "a1", @@ -1886,9 +1733,7 @@ class PermanentUDFSuite extends TestData { "a18", "a19", "a20", - "a21" - ) - ) + "a21")) val func = ( a1: Int, a2: Int, @@ -1910,17 +1755,14 @@ class PermanentUDFSuite extends TestData { a18: Int, a19: Int, a20: Int, - a21: Int - ) => + a21: Int) => a1 + a2 + a3 + a4 + a5 + a6 + a7 + a8 + a9 + a10 + a11 + a12 + a13 + a14 + a15 + a16 + a17 + a18 + a19 + a20 + a21 val funcName: String = randomName() val newDf = newSession .createDataFrame( Seq( (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21), - (11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31) - ) - ) + (11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31))) .toDF( Seq( "a1", @@ -1943,9 +1785,7 @@ class PermanentUDFSuite extends TestData { "a18", "a19", "a20", - "a21" - ) - ) + "a21")) try { session.udf.registerPermanent(funcName, func, stageName) checkAnswer( @@ -1972,12 +1812,9 @@ class PermanentUDFSuite extends TestData { df("a18"), df("a19"), df("a20"), - df("a21") - ) - ), + df("a21"))), Seq(Row(231), Row(441)), - sort = false - ) + sort = false) checkAnswer( newDf.select( callUDF( @@ -2002,17 +1839,13 @@ class PermanentUDFSuite extends TestData { newDf("a18"), newDf("a19"), newDf("a20"), - newDf("a21") - ) - ), + newDf("a21"))), Seq(Row(231), Row(441)), - sort = false - ) + sort = false) } finally { runQuery( s"drop function $funcName(INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT)", - session - ) + session) } // scalastyle:on } @@ -2024,9 +1857,7 @@ class PermanentUDFSuite extends TestData { .createDataFrame( Seq( (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22), - (11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32) - ) - ) + (11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32))) 
.toDF( Seq( "a1", @@ -2050,9 +1881,7 @@ class PermanentUDFSuite extends TestData { "a19", "a20", "a21", - "a22" - ) - ) + "a22")) val func = ( a1: Int, a2: Int, @@ -2075,17 +1904,14 @@ class PermanentUDFSuite extends TestData { a19: Int, a20: Int, a21: Int, - a22: Int - ) => + a22: Int) => a1 + a2 + a3 + a4 + a5 + a6 + a7 + a8 + a9 + a10 + a11 + a12 + a13 + a14 + a15 + a16 + a17 + a18 + a19 + a20 + a21 + a22 val funcName: String = randomName() val newDf = newSession .createDataFrame( Seq( (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22), - (11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32) - ) - ) + (11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32))) .toDF( Seq( "a1", @@ -2109,9 +1935,7 @@ class PermanentUDFSuite extends TestData { "a19", "a20", "a21", - "a22" - ) - ) + "a22")) try { session.udf.registerPermanent(funcName, func, stageName) checkAnswer( @@ -2139,12 +1963,9 @@ class PermanentUDFSuite extends TestData { df("a19"), df("a20"), df("a21"), - df("a22") - ) - ), + df("a22"))), Seq(Row(253), Row(473)), - sort = false - ) + sort = false) checkAnswer( newDf.select( callUDF( @@ -2170,17 +1991,13 @@ class PermanentUDFSuite extends TestData { newDf("a19"), newDf("a20"), newDf("a21"), - newDf("a22") - ) - ), + newDf("a22"))), Seq(Row(253), Row(473)), - sort = false - ) + sort = false) } finally { runQuery( s"drop function $funcName(INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT)", - session - ) + session) } // scalastyle:on } @@ -2317,11 +2134,9 @@ class PermanentUDFSuite extends TestData { session.file.put( TestUtils.escapePath( tempDirectory1.getCanonicalPath + - TestUtils.fileSeparator + fileName - ), + TestUtils.fileSeparator + fileName), stageName1, - Map("AUTO_COMPRESS" -> "FALSE") - ) + Map("AUTO_COMPRESS" -> "FALSE")) session.addDependency(s"@$stageName1/" + fileName) val df1 = 
session.createDataFrame(Seq(fileName)).toDF(Seq("a")) @@ -2350,8 +2165,7 @@ class PermanentUDFSuite extends TestData { 4L, 1.1f, 1.2d, - new java.math.BigDecimal(1.3).setScale(3, RoundingMode.HALF_DOWN) - ) + new java.math.BigDecimal(1.3).setScale(3, RoundingMode.HALF_DOWN)) val func = (a: Short, b: Int, c: Long, d: Float, e: Double, f: java.math.BigDecimal) => s"$a $b $c $d $e $f" @@ -2368,8 +2182,7 @@ class PermanentUDFSuite extends TestData { val df2 = session .range(1) .select( - callBuiltin(funcName, values._1, values._2, values._3, values._4, values._5, values._6) - ) + callBuiltin(funcName, values._1, values._2, values._3, values._4, values._5, values._6)) checkAnswer(df2, Seq(Row("2 3 4 1.1 1.2 1.300000000000000000"))) // test builtin()() val df3 = session @@ -2379,8 +2192,7 @@ class PermanentUDFSuite extends TestData { } finally { runQuery( s"drop function if exists $funcName(SMALLINT,INT,BIGINT,FLOAT,DOUBLE,NUMBER(38,18))", - session - ) + session) } } @@ -2391,8 +2203,7 @@ class PermanentUDFSuite extends TestData { true, Array(61.toByte, 62.toByte), new Timestamp(timestamp - 100), - new Date(timestamp - 100) - ) + new Date(timestamp - 100)) val func = (a: String, b: Boolean, c: Array[Byte], d: Timestamp, e: Date) => s"$a $b 0x${c.map { _.toHexString }.mkString("")} $d $e" diff --git a/src/test/scala/com/snowflake/snowpark_test/RequestTimeoutSuite.scala b/src/test/scala/com/snowflake/snowpark_test/RequestTimeoutSuite.scala index 293b2b60..9783b1eb 100644 --- a/src/test/scala/com/snowflake/snowpark_test/RequestTimeoutSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/RequestTimeoutSuite.scala @@ -7,7 +7,6 @@ class RequestTimeoutSuite extends UploadTimeoutSession { // Jar upload timeout is set to 0 second test("Test udf jar upload timeout") { assertThrows[SnowparkClientException]( - mockSession.udf.registerTemporary((a: Int, b: Int) => a == b) - ) + mockSession.udf.registerTemporary((a: Int, b: Int) => a == b)) } } diff --git 
a/src/test/scala/com/snowflake/snowpark_test/ResultSchemaSuite.scala b/src/test/scala/com/snowflake/snowpark_test/ResultSchemaSuite.scala index 9e0e2d16..85a67984 100644 --- a/src/test/scala/com/snowflake/snowpark_test/ResultSchemaSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/ResultSchemaSuite.scala @@ -29,8 +29,7 @@ class ResultSchemaSuite extends TestData { .map(row => s"""${row.colName} ${row.sfType}, "${row.colName}" ${row.sfType} not null,""") .reduce((x, y) => x + y) .dropRight(1) - .stripMargin - ) + .stripMargin) createTable( fullTypesTable2, @@ -38,8 +37,7 @@ class ResultSchemaSuite extends TestData { .map(row => s"""${row.colName} ${row.sfType},""") .reduce((x, y) => x + y) .dropRight(1) - .stripMargin - ) + .stripMargin) } override def afterAll: Unit = { @@ -59,8 +57,7 @@ class ResultSchemaSuite extends TestData { test("alter") { verifySchema( "alter session set ABORT_DETACHED_QUERY=false", - session.sql("alter session set ABORT_DETACHED_QUERY=false").schema - ) + session.sql("alter session set ABORT_DETACHED_QUERY=false").schema) } test("list, remove file") { @@ -70,16 +67,14 @@ class ResultSchemaSuite extends TestData { uploadFileToStage(stageName, testFile2Csv, compress = false) verifySchema( s"rm @$stageName/$testFileCsv", - session.sql(s"rm @$stageName/$testFile2Csv").schema - ) + session.sql(s"rm @$stageName/$testFile2Csv").schema) // Re-upload to test remove uploadFileToStage(stageName, testFileCsv, compress = false) uploadFileToStage(stageName, testFile2Csv, compress = false) verifySchema( s"remove @$stageName/$testFileCsv", - session.sql(s"remove @$stageName/$testFile2Csv").schema - ) + session.sql(s"remove @$stageName/$testFile2Csv").schema) } test("select") { @@ -94,8 +89,7 @@ class ResultSchemaSuite extends TestData { val df2 = df1.filter(col("\"int\"") > 0) verifySchema( s"""select string, "int", array, "date" from $fullTypesTable where \"int\" > 0""", - df2.schema - ) + df2.schema) } // ignore it for now since we are 
modifying the analyzer system. @@ -148,9 +142,7 @@ class ResultSchemaSuite extends TestData { resultMeta.getColumnTypeName(index + 1), resultMeta.getPrecision(index + 1), resultMeta.getScale(index + 1), - resultMeta.isSigned(index + 1) - ) == typeMap(index).tsType - ) + resultMeta.isSigned(index + 1)) == typeMap(index).tsType) assert(tsSchema(index).dataType == typeMap(index).tsType) }) statement.close() diff --git a/src/test/scala/com/snowflake/snowpark_test/RowSuite.scala b/src/test/scala/com/snowflake/snowpark_test/RowSuite.scala index f971c044..54aba687 100644 --- a/src/test/scala/com/snowflake/snowpark_test/RowSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/RowSuite.scala @@ -57,9 +57,7 @@ class RowSuite extends SNTestBase { Array[Byte](1, 9), Geography.fromGeoJSON("{\"type\":\"Point\",\"coordinates\":[30,10]}"), Geometry.fromGeoJSON( - "{\"coordinates\": [3.000000000000000e+01,1.000000000000000e+01],\"type\": \"Point\"}" - ) - ) + "{\"coordinates\": [3.000000000000000e+01,1.000000000000000e+01],\"type\": \"Point\"}")) assert(row.length == 20) assert(row.isNullAt(0)) @@ -80,20 +78,16 @@ class RowSuite extends SNTestBase { assertThrows[ClassCastException](row.getBinary(6)) assert( row.getGeography(18) == - Geography.fromGeoJSON("{\"type\":\"Point\",\"coordinates\":[30,10]}") - ) + Geography.fromGeoJSON("{\"type\":\"Point\",\"coordinates\":[30,10]}")) assertThrows[ClassCastException](row.getBinary(18)) assert(row.getString(18) == "{\"type\":\"Point\",\"coordinates\":[30,10]}") assert( row.getGeometry(19) == Geometry.fromGeoJSON( - "{\"coordinates\": [3.000000000000000e+01,1.000000000000000e+01],\"type\": \"Point\"}" - ) - ) + "{\"coordinates\": [3.000000000000000e+01,1.000000000000000e+01],\"type\": \"Point\"}")) assert( row.getString(19) == - "{\"coordinates\": [3.000000000000000e+01,1.000000000000000e+01],\"type\": \"Point\"}" - ) + "{\"coordinates\": [3.000000000000000e+01,1.000000000000000e+01],\"type\": \"Point\"}") } test("number getters") { 
@@ -110,9 +104,7 @@ class RowSuite extends SNTestBase { Float.MinValue, Double.MaxValue, Double.MinValue, - "Str" - ) - ) + "Str")) // getByte assert(testRow.getByte(0) == 1.toByte) diff --git a/src/test/scala/com/snowflake/snowpark_test/ScalaGeographySuite.scala b/src/test/scala/com/snowflake/snowpark_test/ScalaGeographySuite.scala index 72ad7a56..a1ce97f3 100644 --- a/src/test/scala/com/snowflake/snowpark_test/ScalaGeographySuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/ScalaGeographySuite.scala @@ -19,11 +19,9 @@ class ScalaGeographySuite extends SNTestBase { assert( Geography.fromGeoJSON(testData).hashCode() == - Geography.fromGeoJSON(testData).hashCode() - ) + Geography.fromGeoJSON(testData).hashCode()) assert( Geography.fromGeoJSON(testData).hashCode() != - Geography.fromGeoJSON(testData2).hashCode() - ) + Geography.fromGeoJSON(testData2).hashCode()) } } diff --git a/src/test/scala/com/snowflake/snowpark_test/ScalaVariantSuite.scala b/src/test/scala/com/snowflake/snowpark_test/ScalaVariantSuite.scala index 30ace108..35e8c572 100644 --- a/src/test/scala/com/snowflake/snowpark_test/ScalaVariantSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/ScalaVariantSuite.scala @@ -19,8 +19,7 @@ class ScalaVariantSuite extends FunSuite { assert( new Variant(Array(1.toByte, 2.toByte, 3.toByte)) .asBinary() - .sameElements(Array(1.toByte, 2.toByte, 3.toByte)) - ) + .sameElements(Array(1.toByte, 2.toByte, 3.toByte))) val arr = new Variant(Array(true, 1)).asArray() assert(arr(0).asBoolean()) assert(arr(1).asInt() == 1) @@ -28,9 +27,7 @@ class ScalaVariantSuite extends FunSuite { assert(new Variant(Date.valueOf("2020-10-10")).asDate() == Date.valueOf("2020-10-10")) assert( new Variant(Timestamp.valueOf("2020-10-10 01:02:03")).asTimestamp() == Timestamp.valueOf( - "2020-10-10 01:02:03" - ) - ) + "2020-10-10 01:02:03")) val seq = new Variant(Seq(1, 2, 3)).asSeq() assert(seq.head.asInt() == 1) assert(seq(1).asInt() == 2) @@ -284,13 +281,11 @@ class 
ScalaVariantSuite extends FunSuite { assert(new Variant(map2).asJsonString().equals("{\"a\":1,\"b\":\"a\"}")) val map3 = Map( "a" -> Geography.fromGeoJSON("Point(10 10)"), - "b" -> Geography.fromGeoJSON("Point(20 20)") - ) + "b" -> Geography.fromGeoJSON("Point(20 20)")) assert( new Variant(map3) .asJsonString() - .equals("{\"a\":\"Point(10 10)\"," + "\"b\":\"Point(20 20)\"}") - ); + .equals("{\"a\":\"Point(10 10)\"," + "\"b\":\"Point(20 20)\"}")); } test("negative test for conversion") { diff --git a/src/test/scala/com/snowflake/snowpark_test/SessionSuite.scala b/src/test/scala/com/snowflake/snowpark_test/SessionSuite.scala index dfd9cbc6..6f07bb9a 100644 --- a/src/test/scala/com/snowflake/snowpark_test/SessionSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/SessionSuite.scala @@ -129,8 +129,7 @@ class SessionSuite extends SNTestBase { val currentClient = getShowString(session.sql("select current_client()"), 10) assert(currentClient contains "Snowpark") assert( - currentClient contains SnowparkSFConnectionHandler.extractValidVersionNumber(Utils.Version) - ) + currentClient contains SnowparkSFConnectionHandler.extractValidVersionNumber(Utils.Version)) } test("negative test to input invalid table name for Session.table()") { @@ -147,8 +146,7 @@ class SessionSuite extends SNTestBase { checkAnswer( Seq(None, Some(Array(1, 2))).toDF("arr"), Seq(Row(null), Row("[\n 1,\n 2\n]")), - sort = false - ) + sort = false) } test("create dataframe from array") { @@ -283,9 +281,7 @@ class SessionSuite extends SNTestBase { assert( exception.message.startsWith( "Error Code: 0426, Error message: The given query tag must be a valid JSON string. " + - "Ensure it's correctly formatted as JSON." 
- ) - ) + "Ensure it's correctly formatted as JSON.")) } test("updateQueryTag when the query tag of the current session is not a valid JSON") { @@ -297,9 +293,7 @@ class SessionSuite extends SNTestBase { assert( exception.message.startsWith( "Error Code: 0427, Error message: The query tag of the current session must be a valid " + - "JSON string. Current query tag: tag1" - ) - ) + "JSON string. Current query tag: tag1")) } test("updateQueryTag when the query tag of the current session is set with an ALTER SESSION") { @@ -341,13 +335,11 @@ class SessionSuite extends SNTestBase { test("generator") { checkAnswer( session.generator(3, Seq(lit(1).as("a"), lit(2).as("b"))), - Seq(Row(1, 2), Row(1, 2), Row(1, 2)) - ) + Seq(Row(1, 2), Row(1, 2), Row(1, 2))) checkAnswer( session.generator(3, lit(1).as("a"), lit(2).as("b")), - Seq(Row(1, 2), Row(1, 2), Row(1, 2)) - ) + Seq(Row(1, 2), Row(1, 2), Row(1, 2))) val msg = intercept[SnowparkClientException](session.generator(3, Seq.empty)) assert(msg.message.contains("The column list of generator function can not be empty")) @@ -373,8 +365,7 @@ class SessionSuite extends SNTestBase { jsonSessionInfo .get("jdbc.session.id") .asText() - .equals(session.jdbcConnection.asInstanceOf[SnowflakeConnectionV1].getSessionID) - ) + .equals(session.jdbcConnection.asInstanceOf[SnowflakeConnectionV1].getSessionID)) assert(jsonSessionInfo.get("os.name").asText().equals(Utils.OSName)) assert(jsonSessionInfo.get("jdbc.version").asText().equals(SnowflakeDriver.implementVersion)) assert(jsonSessionInfo.get("snowpark.library").asText().nonEmpty) diff --git a/src/test/scala/com/snowflake/snowpark_test/SqlSuite.scala b/src/test/scala/com/snowflake/snowpark_test/SqlSuite.scala index e5c18e42..a0ae3be0 100644 --- a/src/test/scala/com/snowflake/snowpark_test/SqlSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/SqlSuite.scala @@ -32,8 +32,7 @@ trait SqlSuite extends SNTestBase { |return 'Done' |$$$$ |""".stripMargin, - session - ) + session) } 
override def afterAll: Unit = { @@ -243,8 +242,7 @@ class LazySqlSuite extends SqlSuite with LazySession { // test for insertion to a non-existing table session.sql(s"insert into $tableName2 values(1),(2),(3)") // no error assertThrows[SnowflakeSQLException]( - session.sql(s"insert into $tableName2 values(1),(2),(3)").collect() - ) + session.sql(s"insert into $tableName2 values(1),(2),(3)").collect()) // test for insertion with wrong type of data, throws exception when collect val insert2 = session.sql(s"insert into $tableName1 values(1.4),('test')") diff --git a/src/test/scala/com/snowflake/snowpark_test/StoredProcedureSuite.scala b/src/test/scala/com/snowflake/snowpark_test/StoredProcedureSuite.scala index 1daa2109..4f2f6da4 100644 --- a/src/test/scala/com/snowflake/snowpark_test/StoredProcedureSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/StoredProcedureSuite.scala @@ -96,8 +96,7 @@ class StoredProcedureSuite extends SNTestBase { session.sql(query).show() checkAnswer( session.storedProcedure(spName, "a", 1, 1.1, true), - Seq(Row("input: a, 1, 1.1, true")) - ) + Seq(Row("input: a, 1, 1.1, true"))) } finally { session.sql(s"drop procedure if exists $spName (STRING, INT, FLOAT, BOOLEAN)").show() } @@ -131,12 +130,10 @@ class StoredProcedureSuite extends SNTestBase { val num = num1 + num2 + num3 val float = (num4 + num5).ceil s"$num, $float, $bool" - } - ) + }) checkAnswer( session.storedProcedure(sp, 1, 2L, 3.toShort, 4.4f, 5.5, false), - Seq(Row(s"6, 10.0, false")) - ) + Seq(Row(s"6, 10.0, false"))) } test("decimal input") { @@ -170,8 +167,7 @@ class StoredProcedureSuite extends SNTestBase { spName, (_: Session) => s"SUCCESS", stageName, - isCallerMode = true - ) + isCallerMode = true) checkAnswer(session.storedProcedure(sp), Seq(Row("SUCCESS"))) checkAnswer(session.storedProcedure(spName), Seq(Row("SUCCESS"))) } finally { @@ -189,8 +185,7 @@ class StoredProcedureSuite extends SNTestBase { spName, (_: Session, num1: Int) => num1 + 100, stageName, - 
isCallerMode = true - ) + isCallerMode = true) checkAnswer(session.storedProcedure(sp, 1), Seq(Row(101))) checkAnswer(session.storedProcedure(spName, 1), Seq(Row(101))) } finally { @@ -238,8 +233,7 @@ class StoredProcedureSuite extends SNTestBase { spName, (_: Session, num1: Int, num2: Int) => num1 + num2 + 100, stageName, - isCallerMode = true - ) + isCallerMode = true) checkAnswer(session.storedProcedure(sp, 1, 2), Seq(Row(103))) checkAnswer(session.storedProcedure(spName, 1, 2), Seq(Row(103))) } finally { @@ -257,8 +251,7 @@ class StoredProcedureSuite extends SNTestBase { spName, (_: Session, num1: Int, num2: Int, num3: Int) => num1 + num2 + num3 + 100, stageName, - isCallerMode = true - ) + isCallerMode = true) checkAnswer(session.storedProcedure(sp, 1, 2, 3), Seq(Row(106))) checkAnswer(session.storedProcedure(spName, 1, 2, 3), Seq(Row(106))) } finally { @@ -276,8 +269,7 @@ class StoredProcedureSuite extends SNTestBase { spName, (_: Session, num1: Int, num2: Int, num3: Int, num4: Int) => num1 + num2 + num3 + num4 + 100, stageName, - isCallerMode = true - ) + isCallerMode = true) checkAnswer(session.storedProcedure(sp, 1, 2, 3, 4), Seq(Row(110))) checkAnswer(session.storedProcedure(spName, 1, 2, 3, 4), Seq(Row(110))) } finally { @@ -296,8 +288,7 @@ class StoredProcedureSuite extends SNTestBase { (_: Session, num1: Int, num2: Int, num3: Int, num4: Int, num5: Int) => num1 + num2 + num3 + num4 + num5 + 100, stageName, - isCallerMode = true - ) + isCallerMode = true) checkAnswer(session.storedProcedure(sp, 1, 2, 3, 4, 5), Seq(Row(115))) checkAnswer(session.storedProcedure(spName, 1, 2, 3, 4, 5), Seq(Row(115))) } finally { @@ -316,8 +307,7 @@ class StoredProcedureSuite extends SNTestBase { (_: Session, num1: Int, num2: Int, num3: Int, num4: Int, num5: Int, num6: Int) => num1 + num2 + num3 + num4 + num5 + num6 + 100, stageName, - isCallerMode = true - ) + isCallerMode = true) checkAnswer(session.storedProcedure(sp, 1, 2, 3, 4, 5, 6), Seq(Row(121))) 
checkAnswer(session.storedProcedure(spName, 1, 2, 3, 4, 5, 6), Seq(Row(121))) } finally { @@ -336,8 +326,7 @@ class StoredProcedureSuite extends SNTestBase { (_: Session, num1: Int, num2: Int, num3: Int, num4: Int, num5: Int, num6: Int, num7: Int) => num1 + num2 + num3 + num4 + num5 + num6 + num7 + 100, stageName, - isCallerMode = true - ) + isCallerMode = true) checkAnswer(session.storedProcedure(sp, 1, 2, 3, 4, 5, 6, 7), Seq(Row(128))) checkAnswer(session.storedProcedure(spName, 1, 2, 3, 4, 5, 6, 7), Seq(Row(128))) } finally { @@ -362,11 +351,9 @@ class StoredProcedureSuite extends SNTestBase { num5: Int, num6: Int, num7: Int, - num8: Int - ) => num1 + num2 + num3 + num4 + num5 + num6 + num7 + num8 + 100, + num8: Int) => num1 + num2 + num3 + num4 + num5 + num6 + num7 + num8 + 100, stageName, - isCallerMode = true - ) + isCallerMode = true) checkAnswer(session.storedProcedure(sp, 1, 2, 3, 4, 5, 6, 7, 8), Seq(Row(136))) checkAnswer(session.storedProcedure(spName, 1, 2, 3, 4, 5, 6, 7, 8), Seq(Row(136))) } finally { @@ -392,11 +379,9 @@ class StoredProcedureSuite extends SNTestBase { num6: Int, num7: Int, num8: Int, - num9: Int - ) => num1 + num2 + num3 + num4 + num5 + num6 + num7 + num8 + num9 + 100, + num9: Int) => num1 + num2 + num3 + num4 + num5 + num6 + num7 + num8 + num9 + 100, stageName, - isCallerMode = true - ) + isCallerMode = true) checkAnswer(session.storedProcedure(sp, 1, 2, 3, 4, 5, 6, 7, 8, 9), Seq(Row(145))) checkAnswer(session.storedProcedure(spName, 1, 2, 3, 4, 5, 6, 7, 8, 9), Seq(Row(145))) } finally { @@ -425,11 +410,10 @@ class StoredProcedureSuite extends SNTestBase { num7: Int, num8: Int, num9: Int, - num10: Int - ) => num1 + num2 + num3 + num4 + num5 + num6 + num7 + num8 + num9 + num10 + 100, + num10: Int) => + num1 + num2 + num3 + num4 + num5 + num6 + num7 + num8 + num9 + num10 + 100, stageName, - isCallerMode = true - ) + isCallerMode = true) checkAnswer(session.storedProcedure(sp, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10), Seq(Row(155))) 
checkAnswer(session.storedProcedure(spName, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10), Seq(Row(155))) } finally { @@ -459,11 +443,10 @@ class StoredProcedureSuite extends SNTestBase { num8: Int, num9: Int, num10: Int, - num11: Int - ) => num1 + num2 + num3 + num4 + num5 + num6 + num7 + num8 + num9 + num10 + num11 + 100, + num11: Int) => + num1 + num2 + num3 + num4 + num5 + num6 + num7 + num8 + num9 + num10 + num11 + 100, stageName, - isCallerMode = true - ) + isCallerMode = true) checkAnswer(session.storedProcedure(sp, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11), Seq(Row(166))) checkAnswer(session.storedProcedure(spName, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11), Seq(Row(166))) } finally { @@ -494,18 +477,15 @@ class StoredProcedureSuite extends SNTestBase { num9: Int, num10: Int, num11: Int, - num12: Int - ) => + num12: Int) => num1 + num2 + num3 + num4 + num5 + num6 + num7 + num8 + num9 + num10 + num11 + num12 + 100, stageName, - isCallerMode = true - ) + isCallerMode = true) checkAnswer(session.storedProcedure(sp, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12), Seq(Row(178))) checkAnswer( session.storedProcedure(spName, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12), - Seq(Row(178)) - ) + Seq(Row(178))) } finally { dropStage(stageName) session @@ -535,27 +515,22 @@ class StoredProcedureSuite extends SNTestBase { num10: Int, num11: Int, num12: Int, - num13: Int - ) => + num13: Int) => num1 + num2 + num3 + num4 + num5 + num6 + num7 + num8 + num9 + num10 + num11 + num12 + num13 + 100, stageName, - isCallerMode = true - ) + isCallerMode = true) checkAnswer( session.storedProcedure(sp, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13), - Seq(Row(191)) - ) + Seq(Row(191))) checkAnswer( session.storedProcedure(spName, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13), - Seq(Row(191)) - ) + Seq(Row(191))) } finally { dropStage(stageName) session .sql( - s"drop procedure if exists $spName (INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT)" - ) + s"drop procedure if exists $spName 
(INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT)") .show() } } @@ -582,28 +557,23 @@ class StoredProcedureSuite extends SNTestBase { num11: Int, num12: Int, num13: Int, - num14: Int - ) => + num14: Int) => num1 + num2 + num3 + num4 + num5 + num6 + num7 + num8 + num9 + num10 + num11 + num12 + num13 + num14 + 100, stageName, - isCallerMode = true - ) + isCallerMode = true) checkAnswer( session.storedProcedure(sp, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14), - Seq(Row(205)) - ) + Seq(Row(205))) checkAnswer( session.storedProcedure(spName, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14), - Seq(Row(205)) - ) + Seq(Row(205))) } finally { dropStage(stageName) session .sql( s"drop procedure if exists $spName (INT,INT,INT,INT,INT,INT,INT," + - "INT,INT,INT,INT,INT,INT,INT)" - ) + "INT,INT,INT,INT,INT,INT,INT)") .show() } } @@ -631,28 +601,23 @@ class StoredProcedureSuite extends SNTestBase { num12: Int, num13: Int, num14: Int, - num15: Int - ) => + num15: Int) => num1 + num2 + num3 + num4 + num5 + num6 + num7 + num8 + num9 + num10 + num11 + num12 + num13 + num14 + num15 + 100, stageName, - isCallerMode = true - ) + isCallerMode = true) checkAnswer( session.storedProcedure(sp, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15), - Seq(Row(220)) - ) + Seq(Row(220))) checkAnswer( session.storedProcedure(spName, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15), - Seq(Row(220)) - ) + Seq(Row(220))) } finally { dropStage(stageName) session .sql( s"drop procedure if exists $spName (INT,INT,INT,INT,INT,INT,INT," + - "INT,INT,INT,INT,INT,INT,INT,INT)" - ) + "INT,INT,INT,INT,INT,INT,INT,INT)") .show() } } @@ -681,28 +646,23 @@ class StoredProcedureSuite extends SNTestBase { num13: Int, num14: Int, num15: Int, - num16: Int - ) => + num16: Int) => num1 + num2 + num3 + num4 + num5 + num6 + num7 + num8 + num9 + num10 + num11 + num12 + num13 + num14 + num15 + num16 + 100, stageName, - isCallerMode = true - ) + isCallerMode = true) checkAnswer( session.storedProcedure(sp, 1, 2, 3, 
4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16), - Seq(Row(236)) - ) + Seq(Row(236))) checkAnswer( session.storedProcedure(spName, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16), - Seq(Row(236)) - ) + Seq(Row(236))) } finally { dropStage(stageName) session .sql( s"drop procedure if exists $spName (INT,INT,INT,INT,INT,INT,INT," + - "INT,INT,INT,INT,INT,INT,INT,INT,INT)" - ) + "INT,INT,INT,INT,INT,INT,INT,INT,INT)") .show() } } @@ -732,29 +692,24 @@ class StoredProcedureSuite extends SNTestBase { num14: Int, num15: Int, num16: Int, - num17: Int - ) => + num17: Int) => num1 + num2 + num3 + num4 + num5 + num6 + num7 + num8 + num9 + num10 + num11 + num12 + num13 + num14 + num15 + num16 + num17 + 100, stageName, - isCallerMode = true - ) + isCallerMode = true) checkAnswer( session.storedProcedure(sp, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17), - Seq(Row(253)) - ) + Seq(Row(253))) checkAnswer( session .storedProcedure(spName, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17), - Seq(Row(253)) - ) + Seq(Row(253))) } finally { dropStage(stageName) session .sql( s"drop procedure if exists $spName (INT,INT,INT,INT,INT,INT,INT,INT," + - "INT,INT,INT,INT,INT,INT,INT,INT,INT)" - ) + "INT,INT,INT,INT,INT,INT,INT,INT,INT)") .show() } } @@ -785,30 +740,25 @@ class StoredProcedureSuite extends SNTestBase { num15: Int, num16: Int, num17: Int, - num18: Int - ) => + num18: Int) => num1 + num2 + num3 + num4 + num5 + num6 + num7 + num8 + num9 + num10 + num11 + num12 + num13 + num14 + num15 + num16 + num17 + num18 + 100, stageName, - isCallerMode = true - ) + isCallerMode = true) checkAnswer( session .storedProcedure(sp, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18), - Seq(Row(271)) - ) + Seq(Row(271))) checkAnswer( session .storedProcedure(spName, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18), - Seq(Row(271)) - ) + Seq(Row(271))) } finally { dropStage(stageName) session .sql( s"drop procedure if exists $spName 
(INT,INT,INT,INT,INT,INT,INT,INT,INT," + - "INT,INT,INT,INT,INT,INT,INT,INT,INT)" - ) + "INT,INT,INT,INT,INT,INT,INT,INT,INT)") .show() } } @@ -840,18 +790,15 @@ class StoredProcedureSuite extends SNTestBase { num16: Int, num17: Int, num18: Int, - num19: Int - ) => + num19: Int) => num1 + num2 + num3 + num4 + num5 + num6 + num7 + num8 + num9 + num10 + num11 + num12 + num13 + num14 + num15 + num16 + num17 + num18 + num19 + 100, stageName, - isCallerMode = true - ) + isCallerMode = true) checkAnswer( session .storedProcedure(sp, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19), - Seq(Row(290)) - ) + Seq(Row(290))) checkAnswer( session.storedProcedure( spName, @@ -873,17 +820,14 @@ class StoredProcedureSuite extends SNTestBase { 16, 17, 18, - 19 - ), - Seq(Row(290)) - ) + 19), + Seq(Row(290))) } finally { dropStage(stageName) session .sql( s"drop procedure if exists $spName (INT,INT,INT,INT,INT,INT,INT," + - "INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT)" - ) + "INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT)") .show() } } @@ -916,14 +860,12 @@ class StoredProcedureSuite extends SNTestBase { num17: Int, num18: Int, num19: Int, - num20: Int - ) => + num20: Int) => num1 + num2 + num3 + num4 + num5 + num6 + num7 + num8 + num9 + num10 + num11 + num12 + num13 + num14 + num15 + num16 + num17 + num18 + num19 + num20 + 100, stageName, - isCallerMode = true - ) + isCallerMode = true) checkAnswer( session.storedProcedure( sp, @@ -946,10 +888,8 @@ class StoredProcedureSuite extends SNTestBase { 17, 18, 19, - 20 - ), - Seq(Row(310)) - ) + 20), + Seq(Row(310))) checkAnswer( session.storedProcedure( spName, @@ -972,17 +912,14 @@ class StoredProcedureSuite extends SNTestBase { 17, 18, 19, - 20 - ), - Seq(Row(310)) - ) + 20), + Seq(Row(310))) } finally { dropStage(stageName) session .sql( s"drop procedure if exists $spName (INT,INT,INT,INT,INT,INT,INT,INT," + - "INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT)" - ) + 
"INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT)") .show() } } @@ -1016,14 +953,12 @@ class StoredProcedureSuite extends SNTestBase { num18: Int, num19: Int, num20: Int, - num21: Int - ) => + num21: Int) => num1 + num2 + num3 + num4 + num5 + num6 + num7 + num8 + num9 + num10 + num11 + num12 + num13 + num14 + num15 + num16 + num17 + num18 + num19 + num20 + num21 + 100, stageName, - isCallerMode = true - ) + isCallerMode = true) checkAnswer( session.storedProcedure( sp, @@ -1047,10 +982,8 @@ class StoredProcedureSuite extends SNTestBase { 18, 19, 20, - 21 - ), - Seq(Row(331)) - ) + 21), + Seq(Row(331))) checkAnswer( session.storedProcedure( spName, @@ -1074,17 +1007,14 @@ class StoredProcedureSuite extends SNTestBase { 18, 19, 20, - 21 - ), - Seq(Row(331)) - ) + 21), + Seq(Row(331))) } finally { dropStage(stageName) session .sql( s"drop procedure if exists $spName (INT,INT,INT,INT,INT,INT,INT," + - "INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT)" - ) + "INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT,INT)") .show() } } @@ -1099,8 +1029,7 @@ class StoredProcedureSuite extends SNTestBase { spName, (_: Session) => s"SUCCESS", stageName, - isCallerMode = true - ) + isCallerMode = true) checkAnswer(session.storedProcedure(spName), Seq(Row("SUCCESS"))) // works in other sessions checkAnswer(newSession.storedProcedure(spName), Seq(Row("SUCCESS"))) @@ -1121,30 +1050,26 @@ class StoredProcedureSuite extends SNTestBase { spName1, (_: Session) => s"SUCCESS", stageName, - isCallerMode = true - ) + isCallerMode = true) import com.snowflake.snowpark.functions.col checkAnswer( session .sql(s"describe procedure $spName1()") .where(col(""""property"""") === "execute as") .select(col(""""value"""")), - Seq(Row("CALLER")) - ) + Seq(Row("CALLER"))) session.sproc.registerPermanent( spName2, (_: Session) => s"SUCCESS", stageName, - isCallerMode = false - ) + isCallerMode = false) checkAnswer( session .sql(s"describe procedure $spName2()") .where(col(""""property"""") === 
"execute as") .select(col(""""value"""")), - Seq(Row("OWNER")) - ) + Seq(Row("OWNER"))) } finally { dropStage(stageName) session.sql(s"drop procedure if exists $spName1()").show() @@ -1250,8 +1175,7 @@ println(s""" num5: Int, num6: Int, num7: Int, - num8: Int - ) => num1 + num2 + num3 + num4 + num5 + num6 + num7 + num8 + 100 + num8: Int) => num1 + num2 + num3 + num4 + num5 + num6 + num7 + num8 + 100 val sp = session.sproc.registerTemporary(func) val result = session.sproc.runLocally(func, 1, 2, 3, 4, 5, 6, 7, 8) assert(result == 136) @@ -1269,8 +1193,7 @@ println(s""" num6: Int, num7: Int, num8: Int, - num9: Int - ) => num1 + num2 + num3 + num4 + num5 + num6 + num7 + num8 + num9 + 100 + num9: Int) => num1 + num2 + num3 + num4 + num5 + num6 + num7 + num8 + num9 + 100 val sp = session.sproc.registerTemporary(func) val result = session.sproc.runLocally(func, 1, 2, 3, 4, 5, 6, 7, 8, 9) assert(result == 145) @@ -1289,8 +1212,7 @@ println(s""" num7: Int, num8: Int, num9: Int, - num10: Int - ) => num1 + num2 + num3 + num4 + num5 + num6 + num7 + num8 + num9 + num10 + 100 + num10: Int) => num1 + num2 + num3 + num4 + num5 + num6 + num7 + num8 + num9 + num10 + 100 val sp = session.sproc.registerTemporary(func) val result = session.sproc.runLocally(func, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) assert(result == 155) @@ -1310,8 +1232,8 @@ println(s""" num8: Int, num9: Int, num10: Int, - num11: Int - ) => num1 + num2 + num3 + num4 + num5 + num6 + num7 + num8 + num9 + num10 + num11 + 100 + num11: Int) => + num1 + num2 + num3 + num4 + num5 + num6 + num7 + num8 + num9 + num10 + num11 + 100 val sp = session.sproc.registerTemporary(func) val result = session.sproc.runLocally(func, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11) assert(result == 166) @@ -1332,15 +1254,14 @@ println(s""" num9: Int, num10: Int, num11: Int, - num12: Int - ) => num1 + num2 + num3 + num4 + num5 + num6 + num7 + num8 + num9 + num10 + num11 + num12 + 100 + num12: Int) => + num1 + num2 + num3 + num4 + num5 + num6 + num7 + num8 + 
num9 + num10 + num11 + num12 + 100 val sp = session.sproc.registerTemporary(func) val result = session.sproc.runLocally(func, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) assert(result == 178) checkAnswer( session.storedProcedure(sp, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12), - Seq(Row(result)) - ) + Seq(Row(result))) } test("anonymous temporary: 13 args", JavaStoredProcExclude) { @@ -1358,8 +1279,7 @@ println(s""" num10: Int, num11: Int, num12: Int, - num13: Int - ) => + num13: Int) => num1 + num2 + num3 + num4 + num5 + num6 + num7 + num8 + num9 + num10 + num11 + num12 + num13 + 100 val sp = session.sproc.registerTemporary(func) @@ -1367,8 +1287,7 @@ println(s""" assert(result == 191) checkAnswer( session.storedProcedure(sp, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13), - Seq(Row(result)) - ) + Seq(Row(result))) } test("anonymous temporary: 14 args", JavaStoredProcExclude) { @@ -1387,8 +1306,7 @@ println(s""" num11: Int, num12: Int, num13: Int, - num14: Int - ) => + num14: Int) => num1 + num2 + num3 + num4 + num5 + num6 + num7 + num8 + num9 + num10 + num11 + num12 + num13 + num14 + 100 val sp = session.sproc.registerTemporary(func) @@ -1396,8 +1314,7 @@ println(s""" assert(result == 205) checkAnswer( session.storedProcedure(sp, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14), - Seq(Row(result)) - ) + Seq(Row(result))) } test("anonymous temporary: 15 args", JavaStoredProcExclude) { @@ -1417,8 +1334,7 @@ println(s""" num12: Int, num13: Int, num14: Int, - num15: Int - ) => + num15: Int) => num1 + num2 + num3 + num4 + num5 + num6 + num7 + num8 + num9 + num10 + num11 + num12 + num13 + num14 + num15 + 100 val sp = session.sproc.registerTemporary(func) @@ -1426,8 +1342,7 @@ println(s""" assert(result == 220) checkAnswer( session.storedProcedure(sp, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15), - Seq(Row(result)) - ) + Seq(Row(result))) } test("anonymous temporary: 16 args", JavaStoredProcExclude) { @@ -1448,8 +1363,7 @@ println(s""" num13: Int, num14: Int, num15: Int, - num16: 
Int - ) => + num16: Int) => num1 + num2 + num3 + num4 + num5 + num6 + num7 + num8 + num9 + num10 + num11 + num12 + num13 + num14 + num15 + num16 + 100 val sp = session.sproc.registerTemporary(func) @@ -1458,8 +1372,7 @@ println(s""" assert(result == 236) checkAnswer( session.storedProcedure(sp, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16), - Seq(Row(result)) - ) + Seq(Row(result))) } test("anonymous temporary: 17 args", JavaStoredProcExclude) { @@ -1481,8 +1394,7 @@ println(s""" num14: Int, num15: Int, num16: Int, - num17: Int - ) => + num17: Int) => num1 + num2 + num3 + num4 + num5 + num6 + num7 + num8 + num9 + num10 + num11 + num12 + num13 + num14 + num15 + num16 + num17 + 100 val sp = session.sproc.registerTemporary(func) @@ -1491,8 +1403,7 @@ println(s""" assert(result == 253) checkAnswer( session.storedProcedure(sp, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17), - Seq(Row(result)) - ) + Seq(Row(result))) } test("anonymous temporary: 18 args", JavaStoredProcExclude) { @@ -1515,8 +1426,7 @@ println(s""" num15: Int, num16: Int, num17: Int, - num18: Int - ) => + num18: Int) => num1 + num2 + num3 + num4 + num5 + num6 + num7 + num8 + num9 + num10 + num11 + num12 + num13 + num14 + num15 + num16 + num17 + num18 + 100 val sp = session.sproc.registerTemporary(func) @@ -1525,8 +1435,7 @@ println(s""" assert(result == 271) checkAnswer( session.storedProcedure(sp, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18), - Seq(Row(result)) - ) + Seq(Row(result))) } test("anonymous temporary: 19 args", JavaStoredProcExclude) { @@ -1550,8 +1459,7 @@ println(s""" num16: Int, num17: Int, num18: Int, - num19: Int - ) => + num19: Int) => num1 + num2 + num3 + num4 + num5 + num6 + num7 + num8 + num9 + num10 + num11 + num12 + num13 + num14 + num15 + num16 + num17 + num18 + num19 + 100 val sp = session.sproc.registerTemporary(func) @@ -1575,14 +1483,12 @@ println(s""" 16, 17, 18, - 19 - ) + 19) assert(result == 290) checkAnswer( session 
.storedProcedure(sp, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19), - Seq(Row(result)) - ) + Seq(Row(result))) } test("anonymous temporary: 20 args", JavaStoredProcExclude) { @@ -1607,8 +1513,7 @@ println(s""" num17: Int, num18: Int, num19: Int, - num20: Int - ) => + num20: Int) => num1 + num2 + num3 + num4 + num5 + num6 + num7 + num8 + num9 + num10 + num11 + num12 + num13 + num14 + num15 + num16 + num17 + num18 + num19 + num20 + 100 val sp = session.sproc.registerTemporary(func) @@ -1633,14 +1538,12 @@ println(s""" 17, 18, 19, - 20 - ) + 20) assert(result == 310) checkAnswer( session .storedProcedure(sp, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20), - Seq(Row(result)) - ) + Seq(Row(result))) } test("anonymous temporary: 21 args", JavaStoredProcExclude) { @@ -1666,8 +1569,7 @@ println(s""" num18: Int, num19: Int, num20: Int, - num21: Int - ) => + num21: Int) => num1 + num2 + num3 + num4 + num5 + num6 + num7 + num8 + num9 + num10 + num11 + num12 + num13 + num14 + num15 + num16 + num17 + num18 + num19 + num20 + num21 + 100 @@ -1694,8 +1596,7 @@ println(s""" 18, 19, 20, - 21 - ) + 21) assert(result == 331) checkAnswer( session.storedProcedure( @@ -1720,18 +1621,15 @@ println(s""" 18, 19, 20, - 21 - ), - Seq(Row(result)) - ) + 21), + Seq(Row(result))) } test("named temporary: duplicated name") { val name = randomName() val sp1 = session.sproc.registerTemporary(name, (_: Session) => s"SP 1") val msg = intercept[SnowflakeSQLException]( - session.sproc.registerTemporary(name, (_: Session) => s"SP 2") - ).getMessage + session.sproc.registerTemporary(name, (_: Session) => s"SP 2")).getMessage assert(msg.contains("already exists")) } @@ -1779,8 +1677,7 @@ println(s""" val name = randomName() val sp = session.sproc.registerTemporary( name, - (_: Session, num1: Int, num2: Int, num3: Int) => num1 + num2 + num3 + 100 - ) + (_: Session, num1: Int, num2: Int, num3: Int) => num1 + num2 + num3 + 100) 
checkAnswer(session.storedProcedure(sp, 1, 2, 3), Seq(Row(106))) checkAnswer(session.storedProcedure(name, 1, 2, 3), Seq(Row(106))) } @@ -1789,8 +1686,7 @@ println(s""" val name = randomName() val sp = session.sproc.registerTemporary( name, - (_: Session, num1: Int, num2: Int, num3: Int, num4: Int) => num1 + num2 + num3 + num4 + 100 - ) + (_: Session, num1: Int, num2: Int, num3: Int, num4: Int) => num1 + num2 + num3 + num4 + 100) checkAnswer(session.storedProcedure(sp, 1, 2, 3, 4), Seq(Row(110))) checkAnswer(session.storedProcedure(name, 1, 2, 3, 4), Seq(Row(110))) } @@ -1800,8 +1696,7 @@ println(s""" val sp = session.sproc.registerTemporary( name, (_: Session, num1: Int, num2: Int, num3: Int, num4: Int, num5: Int) => - num1 + num2 + num3 + num4 + num5 + 100 - ) + num1 + num2 + num3 + num4 + num5 + 100) checkAnswer(session.storedProcedure(sp, 1, 2, 3, 4, 5), Seq(Row(115))) checkAnswer(session.storedProcedure(name, 1, 2, 3, 4, 5), Seq(Row(115))) } @@ -1811,8 +1706,7 @@ println(s""" val sp = session.sproc.registerTemporary( name, (_: Session, num1: Int, num2: Int, num3: Int, num4: Int, num5: Int, num6: Int) => - num1 + num2 + num3 + num4 + num5 + num6 + 100 - ) + num1 + num2 + num3 + num4 + num5 + num6 + 100) checkAnswer(session.storedProcedure(sp, 1, 2, 3, 4, 5, 6), Seq(Row(121))) checkAnswer(session.storedProcedure(name, 1, 2, 3, 4, 5, 6), Seq(Row(121))) } @@ -1822,8 +1716,7 @@ println(s""" val sp = session.sproc.registerTemporary( name, (_: Session, num1: Int, num2: Int, num3: Int, num4: Int, num5: Int, num6: Int, num7: Int) => - num1 + num2 + num3 + num4 + num5 + num6 + num7 + 100 - ) + num1 + num2 + num3 + num4 + num5 + num6 + num7 + 100) checkAnswer(session.storedProcedure(sp, 1, 2, 3, 4, 5, 6, 7), Seq(Row(128))) checkAnswer(session.storedProcedure(name, 1, 2, 3, 4, 5, 6, 7), Seq(Row(128))) } @@ -1841,9 +1734,7 @@ println(s""" num5: Int, num6: Int, num7: Int, - num8: Int - ) => num1 + num2 + num3 + num4 + num5 + num6 + num7 + num8 + 100 - ) + num8: Int) => num1 
+ num2 + num3 + num4 + num5 + num6 + num7 + num8 + 100) checkAnswer(session.storedProcedure(sp, 1, 2, 3, 4, 5, 6, 7, 8), Seq(Row(136))) checkAnswer(session.storedProcedure(name, 1, 2, 3, 4, 5, 6, 7, 8), Seq(Row(136))) } @@ -1862,9 +1753,7 @@ println(s""" num6: Int, num7: Int, num8: Int, - num9: Int - ) => num1 + num2 + num3 + num4 + num5 + num6 + num7 + num8 + num9 + 100 - ) + num9: Int) => num1 + num2 + num3 + num4 + num5 + num6 + num7 + num8 + num9 + 100) checkAnswer(session.storedProcedure(sp, 1, 2, 3, 4, 5, 6, 7, 8, 9), Seq(Row(145))) checkAnswer(session.storedProcedure(name, 1, 2, 3, 4, 5, 6, 7, 8, 9), Seq(Row(145))) } @@ -1884,9 +1773,7 @@ println(s""" num7: Int, num8: Int, num9: Int, - num10: Int - ) => num1 + num2 + num3 + num4 + num5 + num6 + num7 + num8 + num9 + num10 + 100 - ) + num10: Int) => num1 + num2 + num3 + num4 + num5 + num6 + num7 + num8 + num9 + num10 + 100) checkAnswer(session.storedProcedure(sp, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10), Seq(Row(155))) checkAnswer(session.storedProcedure(name, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10), Seq(Row(155))) } @@ -1907,9 +1794,8 @@ println(s""" num8: Int, num9: Int, num10: Int, - num11: Int - ) => num1 + num2 + num3 + num4 + num5 + num6 + num7 + num8 + num9 + num10 + num11 + 100 - ) + num11: Int) => + num1 + num2 + num3 + num4 + num5 + num6 + num7 + num8 + num9 + num10 + num11 + 100) checkAnswer(session.storedProcedure(sp, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11), Seq(Row(166))) checkAnswer(session.storedProcedure(name, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11), Seq(Row(166))) } @@ -1931,10 +1817,8 @@ println(s""" num9: Int, num10: Int, num11: Int, - num12: Int - ) => - num1 + num2 + num3 + num4 + num5 + num6 + num7 + num8 + num9 + num10 + num11 + num12 + 100 - ) + num12: Int) => + num1 + num2 + num3 + num4 + num5 + num6 + num7 + num8 + num9 + num10 + num11 + num12 + 100) checkAnswer(session.storedProcedure(sp, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12), Seq(Row(178))) checkAnswer(session.storedProcedure(name, 1, 2, 3, 4, 5, 6, 7, 8, 9, 
10, 11, 12), Seq(Row(178))) } @@ -1957,19 +1841,15 @@ println(s""" num10: Int, num11: Int, num12: Int, - num13: Int - ) => + num13: Int) => num1 + num2 + num3 + num4 + num5 + num6 + num7 + num8 + num9 + - num10 + num11 + num12 + num13 + 100 - ) + num10 + num11 + num12 + num13 + 100) checkAnswer( session.storedProcedure(sp, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13), - Seq(Row(191)) - ) + Seq(Row(191))) checkAnswer( session.storedProcedure(name, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13), - Seq(Row(191)) - ) + Seq(Row(191))) } test("named temporary: 14 args", JavaStoredProcExclude) { @@ -1991,19 +1871,15 @@ println(s""" num11: Int, num12: Int, num13: Int, - num14: Int - ) => + num14: Int) => num1 + num2 + num3 + num4 + num5 + num6 + num7 + num8 + num9 + - num10 + num11 + num12 + num13 + num14 + 100 - ) + num10 + num11 + num12 + num13 + num14 + 100) checkAnswer( session.storedProcedure(sp, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14), - Seq(Row(205)) - ) + Seq(Row(205))) checkAnswer( session.storedProcedure(name, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14), - Seq(Row(205)) - ) + Seq(Row(205))) } test("named temporary: 15 args", JavaStoredProcExclude) { @@ -2026,19 +1902,15 @@ println(s""" num12: Int, num13: Int, num14: Int, - num15: Int - ) => + num15: Int) => num1 + num2 + num3 + num4 + num5 + num6 + num7 + num8 + num9 + - num10 + num11 + num12 + num13 + num14 + num15 + 100 - ) + num10 + num11 + num12 + num13 + num14 + num15 + 100) checkAnswer( session.storedProcedure(sp, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15), - Seq(Row(220)) - ) + Seq(Row(220))) checkAnswer( session.storedProcedure(name, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15), - Seq(Row(220)) - ) + Seq(Row(220))) } test("named temporary: 16 args", JavaStoredProcExclude) { @@ -2062,19 +1934,15 @@ println(s""" num13: Int, num14: Int, num15: Int, - num16: Int - ) => + num16: Int) => num1 + num2 + num3 + num4 + num5 + num6 + num7 + num8 + num9 + - num10 + num11 + num12 + num13 + num14 + 
num15 + num16 + 100 - ) + num10 + num11 + num12 + num13 + num14 + num15 + num16 + 100) checkAnswer( session.storedProcedure(sp, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16), - Seq(Row(236)) - ) + Seq(Row(236))) checkAnswer( session.storedProcedure(name, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16), - Seq(Row(236)) - ) + Seq(Row(236))) } test("named temporary: 17 args", JavaStoredProcExclude) { @@ -2099,19 +1967,15 @@ println(s""" num14: Int, num15: Int, num16: Int, - num17: Int - ) => + num17: Int) => num1 + num2 + num3 + num4 + num5 + num6 + num7 + num8 + num9 + - num10 + num11 + num12 + num13 + num14 + num15 + num16 + num17 + 100 - ) + num10 + num11 + num12 + num13 + num14 + num15 + num16 + num17 + 100) checkAnswer( session.storedProcedure(sp, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17), - Seq(Row(253)) - ) + Seq(Row(253))) checkAnswer( session.storedProcedure(name, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17), - Seq(Row(253)) - ) + Seq(Row(253))) } test("named temporary: 18 args", JavaStoredProcExclude) { @@ -2137,20 +2001,16 @@ println(s""" num15: Int, num16: Int, num17: Int, - num18: Int - ) => + num18: Int) => num1 + num2 + num3 + num4 + num5 + num6 + num7 + num8 + num9 + - num10 + num11 + num12 + num13 + num14 + num15 + num16 + num17 + num18 + 100 - ) + num10 + num11 + num12 + num13 + num14 + num15 + num16 + num17 + num18 + 100) checkAnswer( session.storedProcedure(sp, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18), - Seq(Row(271)) - ) + Seq(Row(271))) checkAnswer( session .storedProcedure(name, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18), - Seq(Row(271)) - ) + Seq(Row(271))) } test("named temporary: 19 args", JavaStoredProcExclude) { @@ -2177,21 +2037,17 @@ println(s""" num16: Int, num17: Int, num18: Int, - num19: Int - ) => + num19: Int) => num1 + num2 + num3 + num4 + num5 + num6 + num7 + num8 + num9 + - num10 + num11 + num12 + num13 + num14 + num15 + num16 + num17 + num18 
+ num19 + 100 - ) + num10 + num11 + num12 + num13 + num14 + num15 + num16 + num17 + num18 + num19 + 100) checkAnswer( session .storedProcedure(sp, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19), - Seq(Row(290)) - ) + Seq(Row(290))) checkAnswer( session .storedProcedure(name, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19), - Seq(Row(290)) - ) + Seq(Row(290))) } test("named temporary: 20 args", JavaStoredProcExclude) { @@ -2219,17 +2075,14 @@ println(s""" num17: Int, num18: Int, num19: Int, - num20: Int - ) => + num20: Int) => num1 + num2 + num3 + num4 + num5 + num6 + num7 + num8 + num9 + num10 + num11 + num12 + num13 + num14 + num15 + num16 + num17 + - num18 + num19 + num20 + 100 - ) + num18 + num19 + num20 + 100) checkAnswer( session .storedProcedure(sp, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20), - Seq(Row(310)) - ) + Seq(Row(310))) checkAnswer( session.storedProcedure( name, @@ -2252,10 +2105,8 @@ println(s""" 17, 18, 19, - 20 - ), - Seq(Row(310)) - ) + 20), + Seq(Row(310))) } test("named temporary: 21 args", JavaStoredProcExclude) { @@ -2284,12 +2135,10 @@ println(s""" num18: Int, num19: Int, num20: Int, - num21: Int - ) => + num21: Int) => num1 + num2 + num3 + num4 + num5 + num6 + num7 + num8 + num9 + num10 + num11 + num12 + num13 + num14 + num15 + num16 + num17 + - num18 + num19 + num20 + num21 + 100 - ) + num18 + num19 + num20 + num21 + 100) checkAnswer( session.storedProcedure( sp, @@ -2313,10 +2162,8 @@ println(s""" 18, 19, 20, - 21 - ), - Seq(Row(331)) - ) + 21), + Seq(Row(331))) checkAnswer( session.storedProcedure( name, @@ -2340,10 +2187,8 @@ println(s""" 18, 19, 20, - 21 - ), - Seq(Row(331)) - ) + 21), + Seq(Row(331))) } test("temp is temp") { diff --git a/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala b/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala index b15e1575..e54e291e 100644 --- 
a/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala @@ -12,30 +12,25 @@ class TableFunctionSuite extends TestData { checkAnswer( df.join(tableFunctions.flatten, Map("input" -> parse_json(df("a")))).select("value"), Seq(Row("1"), Row("2"), Row("3"), Row("4")), - sort = false - ) + sort = false) checkAnswer( df.join(TableFunction("flatten"), Map("input" -> parse_json(df("a")))) .select("value"), Seq(Row("1"), Row("2"), Row("3"), Row("4")), - sort = false - ) + sort = false) checkAnswer( df.join(tableFunctions.split_to_table, df("a"), lit(",")).select("value"), Seq(Row("[1"), Row("2]"), Row("[3"), Row("4]")), - sort = false - ) + sort = false) checkAnswer( df.join(tableFunctions.split_to_table, Seq(df("a"), lit(","))).select("value"), Seq(Row("[1"), Row("2]"), Row("[3"), Row("4]")), - sort = false - ) + sort = false) checkAnswer( df.join(TableFunction("split_to_table"), df("a"), lit(",")).select("value"), Seq(Row("[1"), Row("2]"), Row("[3"), Row("4]")), - sort = false - ) + sort = false) } test("session table functions") { @@ -44,37 +39,32 @@ class TableFunctionSuite extends TestData { .tableFunction(tableFunctions.flatten, Map("input" -> parse_json(lit("[1,2]")))) .select("value"), Seq(Row("1"), Row("2")), - sort = false - ) + sort = false) checkAnswer( session .tableFunction(TableFunction("flatten"), Map("input" -> parse_json(lit("[1,2]")))) .select("value"), Seq(Row("1"), Row("2")), - sort = false - ) + sort = false) checkAnswer( session .tableFunction(tableFunctions.split_to_table, lit("split by space"), lit(" ")) .select("value"), Seq(Row("split"), Row("by"), Row("space")), - sort = false - ) + sort = false) checkAnswer( session .tableFunction(tableFunctions.split_to_table, Seq(lit("split by space"), lit(" "))) .select("value"), Seq(Row("split"), Row("by"), Row("space")), - sort = false - ) + sort = false) checkAnswer( session 
.tableFunction(TableFunction("split_to_table"), lit("split by space"), lit(" ")) .select("value"), Seq(Row("split"), Row("by"), Row("space")), - sort = false - ) + sort = false) } test("session table functions with dataframe columns") { @@ -84,15 +74,13 @@ class TableFunctionSuite extends TestData { .tableFunction(tableFunctions.split_to_table, Seq(df("a"), lit(" "))) .select("value"), Seq(Row("split"), Row("by"), Row("space")), - sort = false - ) + sort = false) checkAnswer( session .tableFunction(TableFunction("split_to_table"), Seq(df("a"), lit(" "))) .select("value"), Seq(Row("split"), Row("by"), Row("space")), - sort = false - ) + sort = false) val df2 = Seq(("[1,2]", "[5,6]"), ("[3,4]", "[7,8]")).toDF(Seq("a", "b")) checkAnswer( @@ -100,15 +88,13 @@ class TableFunctionSuite extends TestData { .tableFunction(tableFunctions.flatten, Map("input" -> parse_json(df2("b")))) .select("value"), Seq(Row("5"), Row("6"), Row("7"), Row("8")), - sort = false - ) + sort = false) checkAnswer( session .tableFunction(TableFunction("flatten"), Map("input" -> parse_json(df2("b")))) .select("value"), Seq(Row("5"), Row("6"), Row("7"), Row("8")), - sort = false - ) + sort = false) val df3 = Seq("[9, 10]").toDF("c") val dfJoined = df2.join(df3) @@ -117,8 +103,7 @@ class TableFunctionSuite extends TestData { .tableFunction(tableFunctions.flatten, Map("input" -> parse_json(dfJoined("b")))) .select("value"), Seq(Row("5"), Row("6"), Row("7"), Row("8")), - sort = false - ) + sort = false) val tableName = randomName() try { @@ -129,8 +114,7 @@ class TableFunctionSuite extends TestData { .tableFunction(tableFunctions.split_to_table, Seq(df4("a"), lit(" "))) .select("value"), Seq(Row("split"), Row("by"), Row("space")), - sort = false - ) + sort = false) } finally { dropTable(tableName) } @@ -146,8 +130,7 @@ class TableFunctionSuite extends TestData { val flattened = table.flatten(table("value")) checkAnswer( flattened.select(table("value"), flattened("value").as("newValue")), - Seq(Row("[\n 
\"a\",\n \"b\"\n]", "\"a\""), Row("[\n \"a\",\n \"b\"\n]", "\"b\"")) - ) + Seq(Row("[\n \"a\",\n \"b\"\n]", "\"a\""), Row("[\n \"a\",\n \"b\"\n]", "\"b\""))) } finally { dropTable(tableName) } @@ -160,8 +143,7 @@ class TableFunctionSuite extends TestData { "Obs1\t-0.74\t-0.2\t0.3", "Obs2\t5442\t0.19\t0.16", "Obs3\t0.34\t0.46\t0.72", - "Obs4\t-0.15\t0.71\t0.13" - ).toDF("line") + "Obs4\t-0.15\t0.71\t0.13").toDF("line") val flattenedMatrix = df .withColumn("rowLabel", split(col("line"), lit("\t"))(0)) @@ -198,28 +180,22 @@ class TableFunctionSuite extends TestData { ||"Obs3" |Sample3 |0.72 | ||"Obs4" |Sample3 |0.13 | |---------------------------------------- - |""".stripMargin - ) + |""".stripMargin) } test("Argument in table function: flatten") { val df = Seq( (1, Array(1, 2, 3), Map("a" -> "b", "c" -> "d")), - (2, Array(11, 22, 33), Map("a1" -> "b1", "c1" -> "d1")) - ).toDF("idx", "arr", "map") + (2, Array(11, 22, 33), Map("a1" -> "b1", "c1" -> "d1"))).toDF("idx", "arr", "map") checkAnswer( df.join(tableFunctions.flatten(df("arr"))) .select("value"), - Seq(Row("1"), Row("2"), Row("3"), Row("11"), Row("22"), Row("33")) - ) + Seq(Row("1"), Row("2"), Row("3"), Row("11"), Row("22"), Row("33"))) // error if it is not a table function val error1 = intercept[SnowparkClientException] { df.join(lit("dummy")) } assert( - error1.message.contains( - "Unsupported join operations, Dataframes can join " + - "with other Dataframes or TableFunctions only" - ) - ) + error1.message.contains("Unsupported join operations, Dataframes can join " + + "with other Dataframes or TableFunctions only")) } test("Argument in table function: flatten2") { @@ -232,12 +208,9 @@ class TableFunctionSuite extends TestData { path = "b", outer = true, recursive = true, - mode = "both" - ) - ) + mode = "both")) .select("value"), - Seq(Row("77"), Row("88")) - ) + Seq(Row("77"), Row("88"))) val df2 = Seq("[]").toDF("col") checkAnswer( @@ -248,12 +221,9 @@ class TableFunctionSuite extends TestData { path = 
"", outer = true, recursive = true, - mode = "both" - ) - ) + mode = "both")) .select("value"), - Seq(Row(null)) - ) + Seq(Row(null))) assert( df1 @@ -263,11 +233,8 @@ class TableFunctionSuite extends TestData { path = "", outer = true, recursive = true, - mode = "both" - ) - ) - .count() == 4 - ) + mode = "both")) + .count() == 4) assert( df1 .join( @@ -276,11 +243,8 @@ class TableFunctionSuite extends TestData { path = "", outer = true, recursive = false, - mode = "both" - ) - ) - .count() == 2 - ) + mode = "both")) + .count() == 2) assert( df1 .join( @@ -289,11 +253,8 @@ class TableFunctionSuite extends TestData { path = "", outer = true, recursive = true, - mode = "array" - ) - ) - .count() == 1 - ) + mode = "array")) + .count() == 1) assert( df1 .join( @@ -302,32 +263,24 @@ class TableFunctionSuite extends TestData { path = "", outer = true, recursive = true, - mode = "object" - ) - ) - .count() == 2 - ) + mode = "object")) + .count() == 2) } test("Argument in table function: flatten - session") { val df = Seq( (1, Array(1, 2, 3), Map("a" -> "b", "c" -> "d")), - (2, Array(11, 22, 33), Map("a1" -> "b1", "c1" -> "d1")) - ).toDF("idx", "arr", "map") + (2, Array(11, 22, 33), Map("a1" -> "b1", "c1" -> "d1"))).toDF("idx", "arr", "map") checkAnswer( session.tableFunction(tableFunctions.flatten(df("arr"))).select("value"), - Seq(Row("1"), Row("2"), Row("3"), Row("11"), Row("22"), Row("33")) - ) + Seq(Row("1"), Row("2"), Row("3"), Row("11"), Row("22"), Row("33"))) // error if it is not a table function val error1 = intercept[SnowparkClientException] { session.tableFunction(lit("dummy")) } assert( - error1.message.contains( - "Invalid input argument, " + - "Session.tableFunction only supports table function arguments" - ) - ) + error1.message.contains("Invalid input argument, " + + "Session.tableFunction only supports table function arguments")) } test("Argument in table function: flatten - session 2") { @@ -340,12 +293,9 @@ class TableFunctionSuite extends TestData { 
path = "b", outer = true, recursive = true, - mode = "both" - ) - ) + mode = "both")) .select("value"), - Seq(Row("77"), Row("88")) - ) + Seq(Row("77"), Row("88"))) } test("Argument in table function: split_to_table") { @@ -353,15 +303,13 @@ class TableFunctionSuite extends TestData { checkAnswer( df.join(tableFunctions.split_to_table(df("data"), ",")).select("value"), - Seq(Row("1"), Row("2"), Row("3"), Row("4")) - ) + Seq(Row("1"), Row("2"), Row("3"), Row("4"))) checkAnswer( session .tableFunction(tableFunctions.split_to_table(df("data"), ",")) .select("value"), - Seq(Row("1"), Row("2"), Row("3"), Row("4")) - ) + Seq(Row("1"), Row("2"), Row("3"), Row("4"))) } test("Argument in table function: table function") { @@ -370,8 +318,7 @@ class TableFunctionSuite extends TestData { checkAnswer( df.join(TableFunction("split_to_table")(df("data"), lit(","))) .select("value"), - Seq(Row("1"), Row("2"), Row("3"), Row("4")) - ) + Seq(Row("1"), Row("2"), Row("3"), Row("4"))) val df1 = Seq("{\"a\":1, \"b\":[77, 88]}").toDF("col") checkAnswer( @@ -383,13 +330,9 @@ class TableFunctionSuite extends TestData { "path" -> lit("b"), "outer" -> lit(true), "recursive" -> lit(true), - "mode" -> lit("both") - ) - ) - ) + "mode" -> lit("both")))) .select("value"), - Seq(Row("77"), Row("88")) - ) + Seq(Row("77"), Row("88"))) } test("table function in select") { @@ -404,23 +347,20 @@ class TableFunctionSuite extends TestData { assert(result2.schema.map(_.name) == Seq("IDX", "SEQ", "INDEX", "VALUE")) checkAnswer( result2, - Seq(Row(1, 1, 1, "1"), Row(1, 1, 2, "2"), Row(2, 2, 1, "3"), Row(2, 2, 2, "4")) - ) + Seq(Row(1, 1, 1, "1"), Row(1, 1, 2, "2"), Row(2, 2, 1, "3"), Row(2, 2, 2, "4"))) // columns + tf + columns val result3 = df.select(df("idx"), tableFunctions.split_to_table(df("data"), ","), df("idx")) assert(result3.schema.map(_.name) == Seq("IDX", "SEQ", "INDEX", "VALUE", "IDX")) checkAnswer( result3, - Seq(Row(1, 1, 1, "1", 1), Row(1, 1, 2, "2", 1), Row(2, 2, 1, "3", 2), Row(2, 2, 2, 
"4", 2)) - ) + Seq(Row(1, 1, 1, "1", 1), Row(1, 1, 2, "2", 1), Row(2, 2, 1, "3", 2), Row(2, 2, 2, "4", 2))) // tf + other express val result4 = df.select(tableFunctions.split_to_table(df("data"), ","), df("idx") + 100) checkAnswer( result4, - Seq(Row(1, 1, "1", 101), Row(1, 2, "2", 101), Row(2, 1, "3", 102), Row(2, 2, "4", 102)) - ) + Seq(Row(1, 1, "1", 101), Row(1, 2, "2", 101), Row(2, 1, "3", 102), Row(2, 2, "4", 102))) } test("table function join with duplicated column name") { @@ -448,8 +388,7 @@ class TableFunctionSuite extends TestData { val df1 = df.select(parse_json(df("a")).cast(types.ArrayType(types.IntegerType)).as("a")) checkAnswer( df1.select(lit(1), tableFunctions.explode(df1("a")), df1("a")(1)), - Seq(Row(1, "1", "2"), Row(1, "2", "2")) - ) + Seq(Row(1, "1", "2"), Row(1, "2", "2"))) } test("explode with map column") { @@ -457,28 +396,23 @@ class TableFunctionSuite extends TestData { val df1 = df.select( parse_json(df("a")) .cast(types.MapType(types.StringType, types.IntegerType)) - .as("a") - ) + .as("a")) checkAnswer( df1.select(lit(1), tableFunctions.explode(df1("a")), df1("a")("a")), - Seq(Row(1, "a", "1", "1"), Row(1, "b", "2", "1")) - ) + Seq(Row(1, "a", "1", "1"), Row(1, "b", "2", "1"))) } test("explode with other column") { val df = Seq("""{"a":1, "b": 2}""").toDF("a") val df1 = df.select( parse_json(df("a")) - .as("a") - ) + .as("a")) val error = intercept[SnowparkClientException] { df1.select(tableFunctions.explode(df1("a"))).show() } assert( error.message.contains( - "the input argument type of Explode function should be either Map or Array types" - ) - ) + "the input argument type of Explode function should be either Map or Array types")) assert(error.message.contains("The input argument type: Variant")) } @@ -494,21 +428,17 @@ class TableFunctionSuite extends TestData { val df1 = df.select( parse_json(df("a")) .cast(types.MapType(types.StringType, types.IntegerType)) - .as("a") - ) + .as("a")) checkAnswer( 
session.tableFunction(tableFunctions.explode(df1("a"))), - Seq(Row("a", "1"), Row("b", "2")) - ) + Seq(Row("a", "1"), Row("b", "2"))) // with literal value checkAnswer( session.tableFunction( tableFunctions - .explode(parse_json(lit("[1, 2]")).cast(types.ArrayType(types.IntegerType))) - ), - Seq(Row("1"), Row("2")) - ) + .explode(parse_json(lit("[1, 2]")).cast(types.ArrayType(types.IntegerType)))), + Seq(Row("1"), Row("2"))) } } diff --git a/src/test/scala/com/snowflake/snowpark_test/TableSuite.scala b/src/test/scala/com/snowflake/snowpark_test/TableSuite.scala index 134253c2..3a9d8ef5 100644 --- a/src/test/scala/com/snowflake/snowpark_test/TableSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/TableSuite.scala @@ -36,14 +36,12 @@ class TableSuite extends TestData { s"insert into $semiStructuredTable select parse_json(a), parse_json(b), " + s"parse_json(a), to_geography(c) from values('[1,2]', '{a:1}', 'POINT(-122.35 37.55)')," + s"('[1,2,3]', '{b:2}', 'POINT(-12 37)') as T(a,b,c)", - session - ) + session) createTable(timeTable, "time time") runQuery( s"insert into $timeTable select to_time(a) from values('09:15:29')," + s"('09:15:29.99999999') as T(a)", - session - ) + session) } override def afterAll: Unit = { @@ -111,8 +109,7 @@ class TableSuite extends TestData { // ErrorIfExists Mode assertThrows[SnowflakeSQLException]( - df.write.mode(SaveMode.ErrorIfExists).saveAsTable(tableName2) - ) + df.write.mode(SaveMode.ErrorIfExists).saveAsTable(tableName2)) } test("Save as Snowflake Table - String Argument") { @@ -209,9 +206,7 @@ class TableSuite extends TestData { ArrayType(StringType), MapType(StringType, StringType), VariantType, - GeographyType - ) - ) + GeographyType)) checkAnswer( df, Seq( @@ -220,20 +215,14 @@ class TableSuite extends TestData { "{\n \"a\": 1\n}", "[\n 1,\n 2\n]", Geography.fromGeoJSON( - "{\n \"coordinates\": [\n -122.35,\n 37.55\n ],\n \"type\": \"Point\"\n}" - ) - ), + "{\n \"coordinates\": [\n -122.35,\n 37.55\n ],\n \"type\": 
\"Point\"\n}")), Row( "[\n 1,\n 2,\n 3\n]", "{\n \"b\": 2\n}", "[\n 1,\n 2,\n 3\n]", Geography.fromGeoJSON( - "{\n \"coordinates\": [\n -12,\n 37\n ],\n \"type\": \"Point\"\n}" - ) - ) - ), - sort = false - ) + "{\n \"coordinates\": [\n -12,\n 37\n ],\n \"type\": \"Point\"\n}"))), + sort = false) } // Contains 'alter session', which is not supported by owner's right java sp @@ -241,8 +230,7 @@ class TableSuite extends TestData { val df2 = session.table(semiStructuredTable).select(col("g1")) assert( df2.collect()(0).getString(0) == - "{\n \"coordinates\": [\n -122.35,\n 37.55\n ],\n \"type\": \"Point\"\n}" - ) + "{\n \"coordinates\": [\n -122.35,\n 37.55\n ],\n \"type\": \"Point\"\n}") assertThrows[ClassCastException](df2.collect()(0).getBinary(0)) assert( getShowString(df2, 10) == @@ -264,16 +252,14 @@ class TableSuite extends TestData { || "type": "Point" | ||} | |---------------------- - |""".stripMargin - ) + |""".stripMargin) testWithAlteredSessionParameter( { assertThrows[SnowparkClientException](df2.collect()) }, "GEOGRAPHY_OUTPUT_FORMAT", - "'WKT'" - ) + "'WKT'") } test("table with time type") { @@ -282,8 +268,7 @@ class TableSuite extends TestData { checkAnswer( df2, Seq(Row(Time.valueOf("09:15:29")), Row(Time.valueOf("09:15:29"))), - sort = false - ) + sort = false) } // getDatabaseFromProperties will read local files, which is not supported in Java SP yet. 
diff --git a/src/test/scala/com/snowflake/snowpark_test/UDFSuite.scala b/src/test/scala/com/snowflake/snowpark_test/UDFSuite.scala index d245b154..278c0ad7 100644 --- a/src/test/scala/com/snowflake/snowpark_test/UDFSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/UDFSuite.scala @@ -29,7 +29,7 @@ trait UDFSuite extends TestData { override def equals(other: Any): Boolean = { other match { case o: NonSerializable => id == o.id - case _ => false + case _ => false } } } @@ -116,13 +116,11 @@ trait UDFSuite extends TestData { runQuery( s"insert into $semiStructuredTable" + s" select (object_construct('1', 'one', '2', 'two'))", - session - ) + session) runQuery( s"insert into $semiStructuredTable" + s" select (object_construct('10', 'ten', '20', 'twenty'))", - session - ) + session) val df = session.table(semiStructuredTable) val mapKeysUdf = udf((x: mutable.Map[String, String]) => x.keys.toArray) @@ -154,15 +152,13 @@ trait UDFSuite extends TestData { s" ( select object_construct('1', 'one', '2', 'two')," + s" object_construct('one', '10', 'two', '20')," + s" 'ID1')", - session - ) + session) runQuery( s"insert into $semiStructuredTable" + s" ( select object_construct('3', 'three', '4', 'four')," + s" object_construct('three', '30', 'four', '40')," + s" 'ID2')", - session - ) + session) val df = session.table(semiStructuredTable) val mapUdf = @@ -193,8 +189,7 @@ trait UDFSuite extends TestData { val replaceUdf = udf((elem: String) => elem.replaceAll("num", "id")) checkAnswer( df1.select(replaceUdf($"a")), - Seq(Row("{\"id\":1,\"str\":\"str1\"}"), Row("{\"id\":2,\"str\":\"str2\"}")) - ) + Seq(Row("{\"id\":1,\"str\":\"str1\"}"), Row("{\"id\":2,\"str\":\"str2\"}"))) } test("view with UDF") { @@ -213,8 +208,7 @@ trait UDFSuite extends TestData { val stringUdf = udf((x: Int) => s"$prefix$x") checkAnswer( df.withColumn("b", stringUdf(col("a"))), - Seq(Row(1, "Hello1"), Row(2, "Hello2"), Row(3, "Hello3")) - ) + Seq(Row(1, "Hello1"), Row(2, "Hello2"), Row(3, 
"Hello3"))) } test("test large closure", JavaStoredProcAWSOnly) { @@ -236,8 +230,7 @@ trait UDFSuite extends TestData { val stringUdf = udf((x: Int) => new java.lang.String(s"$prefix$x")) checkAnswer( df.withColumn("b", stringUdf(col("a"))), - Seq(Row(1, "Hello1"), Row(2, "Hello2"), Row(3, "Hello3")) - ) + Seq(Row(1, "Hello1"), Row(2, "Hello2"), Row(3, "Hello3"))) } test("UDF function with multiple columns") { @@ -245,8 +238,7 @@ trait UDFSuite extends TestData { val sumUDF = udf((x: Int, y: Int) => x + y) checkAnswer( df.withColumn("c", sumUDF(col("a"), col("b"))), - Seq(Row(1, 2, 3), Row(2, 3, 5), Row(3, 4, 7)) - ) + Seq(Row(1, 2, 3), Row(2, 3, 5), Row(3, 4, 7))) } test("Incorrect number of args") { @@ -266,10 +258,8 @@ trait UDFSuite extends TestData { checkAnswer( df.withColumn( "c", - callUDF(s"${session.getFullyQualifiedCurrentSchema}.$functionName", col("a")) - ), - Seq(Row(1, 2), Row(2, 4), Row(3, 6)) - ) + callUDF(s"${session.getFullyQualifiedCurrentSchema}.$functionName", col("a"))), + Seq(Row(1, 2), Row(2, 4), Row(3, 6))) } test("Test for Long data type") { @@ -333,8 +323,7 @@ trait UDFSuite extends TestData { val UDF = udf((a: Option[Boolean], b: Option[Boolean]) => a == b) checkAnswer( df.withColumn("c", UDF($"a", $"b")).select($"c"), - Seq(Row(true), Row(false), Row(true)) - ) + Seq(Row(true), Row(false), Row(true))) } test("Test for double data type") { @@ -343,8 +332,7 @@ trait UDFSuite extends TestData { assert( df.withColumn("c", UDF(col("a"))).collect() sameElements - Array[Row](Row(1.01, 2.02), Row(2.01, 4.02), Row(3.01, 6.02)) - ) + Array[Row](Row(1.01, 2.02), Row(2.01, 4.02), Row(3.01, 6.02))) } test("Test for boolean data type") { @@ -352,8 +340,7 @@ trait UDFSuite extends TestData { val UDF = udf((a: Int, b: Int) => a == b) checkAnswer( df.withColumn("c", UDF($"a", $"b")).select($"c"), - Seq(Row(true), Row(true), Row(false)) - ) + Seq(Row(true), Row(true), Row(false))) } test("Test for binary data type") { @@ -375,8 +362,7 @@ trait UDFSuite 
extends TestData { val input = Seq( (Date.valueOf("2019-01-01"), Timestamp.valueOf("2019-01-01 00:00:00")), (Date.valueOf("2020-01-01"), Timestamp.valueOf("2020-01-01 00:00:00")), - (null, null) - ) + (null, null)) val out = input.map { case (null, null) => Row(null, null) case (a, b) => @@ -391,8 +377,7 @@ trait UDFSuite extends TestData { .withColumn("c", toSNUDF(col("date"))) .withColumn("d", toDateUDF(col("timestamp"))) .select($"c", $"d"), - out - ) + out) } test("Test for time, date, timestamp with snowflake timezone") { @@ -410,8 +395,7 @@ trait UDFSuite extends TestData { df.select(addUDF(col("col1"))) .collect()(0) .getTimestamp(0) - .toString == "2020-01-01 00:00:05.0" - ) + .toString == "2020-01-01 00:00:05.0") } test("Test for Geography data type") { @@ -428,8 +412,7 @@ trait UDFSuite extends TestData { } else { Geography.fromGeoJSON( g.asGeoJSON() - .replace("0", "") - ) + .replace("0", "")) } } }) @@ -439,17 +422,11 @@ trait UDFSuite extends TestData { Seq( Row( Geography.fromGeoJSON( - "{\n \"coordinates\": [\n 3,\n 1\n ],\n \"type\": \"Point\"\n}" - ) - ), + "{\n \"coordinates\": [\n 3,\n 1\n ],\n \"type\": \"Point\"\n}")), Row( Geography.fromGeoJSON( - "{\n \"coordinates\": [\n 50,\n 60\n ],\n \"type\": \"Point\"\n}" - ) - ), - Row(null) - ) - ) + "{\n \"coordinates\": [\n 50,\n 60\n ],\n \"type\": \"Point\"\n}")), + Row(null))) } test("Test for Geometry data type") { @@ -465,8 +442,7 @@ trait UDFSuite extends TestData { Geometry.fromGeoJSON(g.toString) } else { Geometry.fromGeoJSON( - "{\"coordinates\": [3.000000000000000e+01,1.000000000000000e+01],\"type\": \"Point\"}" - ) + "{\"coordinates\": [3.000000000000000e+01,1.000000000000000e+01],\"type\": \"Point\"}") } } }) @@ -488,9 +464,7 @@ trait UDFSuite extends TestData { | ], | "type": "Point" |}""".stripMargin)), - Row(null) - ) - ) + Row(null))) } // Excluding this test for known Timezone issue in stored proc @@ -507,8 +481,7 @@ trait UDFSuite extends TestData { // '2017-02-24 12:00:00.456' -> 
'2017-02-24 12:00:05.456' checkAnswer( variant1.select(variantTimestampUDF(col("timestamp_ntz1"))), - Seq(Row(Timestamp.valueOf("2017-02-24 20:00:05.456"))) - ) + Seq(Row(Timestamp.valueOf("2017-02-24 20:00:05.456")))) } // Excluding this test for known Timezone issue in stored proc @@ -520,8 +493,7 @@ trait UDFSuite extends TestData { // so 20:57:01 -> 04:57:01 + one day. '1970-01-02 04:57:01.0' -> '1970-01-02 04:57:06.0' checkAnswer( variant1.select(variantTimeUDF(col("time1"))), - Seq(Row(Timestamp.valueOf("1970-01-02 04:57:06.0"))) - ) + Seq(Row(Timestamp.valueOf("1970-01-02 04:57:06.0")))) } // Excluding this test for known Timezone issue in stored proc @@ -533,8 +505,7 @@ trait UDFSuite extends TestData { // so 2017-02-24 -> 2017-02-24 08:00:00. '2017-02-24 08:00:00' -> '2017-02-24 08:00:05' checkAnswer( variant1.select(variantUDF(col("date1"))), - Seq(Row(Timestamp.valueOf("2017-02-24 08:00:05.0"))) - ) + Seq(Row(Timestamp.valueOf("2017-02-24 08:00:05.0")))) } test("Test for Variant String input") { @@ -575,8 +546,7 @@ trait UDFSuite extends TestData { checkAnswer( nullJson1.select(variantNullInputUDF(col("v"))), Seq(Row("null"), Row("\"foo\""), Row(null)), - sort = false - ) + sort = false) } test("Test for string Variant output") { @@ -626,8 +596,7 @@ trait UDFSuite extends TestData { udf((_: Variant) => new Variant(Timestamp.valueOf("2020-10-10 01:02:03"))) checkAnswer( variant1.select(variantOutputUDF(col("num1"))), - Seq(Row("\"2020-10-10 01:02:03.0\"")) - ) + Seq(Row("\"2020-10-10 01:02:03.0\""))) } test("Test for Array[Variant]") { @@ -638,14 +607,12 @@ trait UDFSuite extends TestData { // strip \" from it. 
todo: SNOW-254551 checkAnswer( variant1.select(variantUDF(col("arr1"))), - Seq(Row("[\n \"\\\"Example\\\"\",\n \"1\"\n]")) - ) + Seq(Row("[\n \"\\\"Example\\\"\",\n \"1\"\n]"))) variantUDF = udf((v: Array[Variant]) => v ++ Array(null)) checkAnswer( variant1.select(variantUDF(col("arr1"))), - Seq(Row("[\n \"\\\"Example\\\"\",\n undefined\n]")) - ) + Seq(Row("[\n \"\\\"Example\\\"\",\n undefined\n]"))) // UDF that returns null. Need the if ... else ... to define a return type. variantUDF = udf((v: Array[Variant]) => if (true) null else Array(new Variant(1))) @@ -656,19 +623,16 @@ trait UDFSuite extends TestData { var variantUDF = udf((v: mutable.Map[String, Variant]) => v + ("a" -> new Variant(1))) checkAnswer( variant1.select(variantUDF(col("obj1"))), - Seq(Row("{\n \"Tree\": \"\\\"Pine\\\"\",\n \"a\": \"1\"\n}")) - ) + Seq(Row("{\n \"Tree\": \"\\\"Pine\\\"\",\n \"a\": \"1\"\n}"))) variantUDF = udf((v: mutable.Map[String, Variant]) => v + ("a" -> null)) checkAnswer( variant1.select(variantUDF(col("obj1"))), - Seq(Row("{\n \"Tree\": \"\\\"Pine\\\"\",\n \"a\": null\n}")) - ) + Seq(Row("{\n \"Tree\": \"\\\"Pine\\\"\",\n \"a\": null\n}"))) // UDF that returns null. Need the if ... else ... to define a return type. 
variantUDF = udf((v: mutable.Map[String, Variant]) => - if (true) null else mutable.Map[String, Variant]("a" -> new Variant(1)) - ) + if (true) null else mutable.Map[String, Variant]("a" -> new Variant(1))) checkAnswer(variant1.select(variantUDF(col("obj1"))), Seq(Row(null))) } @@ -678,8 +642,7 @@ trait UDFSuite extends TestData { runQuery( s"insert into $tableName select to_time(a), to_timestamp(b) from values('01:02:03', " + s"'1970-01-01 01:02:03'),(null, null) as T(a, b)", - session - ) + session) val times = session.table(tableName) val toSNUDF = udf((x: Time) => if (x == null) null else new Timestamp(x.getTime)) val toTimeUDF = udf((x: Timestamp) => if (x == null) null else new Time(x.getTime)) @@ -689,8 +652,7 @@ trait UDFSuite extends TestData { .withColumn("c", toSNUDF(col("time"))) .withColumn("d", toTimeUDF(col("timestamp"))) .select($"c", $"d"), - Seq(Row(Timestamp.valueOf("1970-01-01 01:02:03"), Time.valueOf("01:02:03")), Row(null, null)) - ) + Seq(Row(Timestamp.valueOf("1970-01-01 01:02:03"), Time.valueOf("01:02:03")), Row(null, null))) } // Excluding the tests for 2 to 22 args from stored procs to limit overall time // of the UDFSuite run as a regression test @@ -738,8 +700,7 @@ trait UDFSuite extends TestData { checkAnswer( df.withColumn("res", callUDF(funcName, col("c1"), col("c2"))) .select("res"), - Seq(Row(result)) - ) + Seq(Row(result))) } test("Test for num args : 3", JavaStoredProcExclude) { @@ -752,17 +713,14 @@ trait UDFSuite extends TestData { session.udf.registerTemporary(funcName, (c1: Int, c2: Int, c3: Int) => c1 + c2 + c3) checkAnswer( df.withColumn("res", sum(col("c1"), col("c2"), col("c3"))).select("res"), - Seq(Row(result)) - ) + Seq(Row(result))) checkAnswer( df.withColumn("res", sum1(col("c1"), col("c2"), col("c3"))).select("res"), - Seq(Row(result)) - ) + Seq(Row(result))) checkAnswer( df.withColumn("res", callUDF(funcName, col("c1"), col("c2"), col("c3"))) .select("res"), - Seq(Row(result)) - ) + Seq(Row(result))) } test("Test 
for num args : 4", JavaStoredProcExclude) { @@ -777,17 +735,14 @@ trait UDFSuite extends TestData { .registerTemporary(funcName, (c1: Int, c2: Int, c3: Int, c4: Int) => c1 + c2 + c3 + c4) checkAnswer( df.withColumn("res", sum(col("c1"), col("c2"), col("c3"), col("c4"))).select("res"), - Seq(Row(result)) - ) + Seq(Row(result))) checkAnswer( df.withColumn("res", sum1(col("c1"), col("c2"), col("c3"), col("c4"))).select("res"), - Seq(Row(result)) - ) + Seq(Row(result))) checkAnswer( df.withColumn("res", callUDF(funcName, col("c1"), col("c2"), col("c3"), col("c4"))) .select("res"), - Seq(Row(result)) - ) + Seq(Row(result))) } test("Test for num args : 5", JavaStoredProcExclude) { @@ -796,28 +751,23 @@ trait UDFSuite extends TestData { val df = Seq((1, 2, 3, 4, 5)).toDF(columns) val sum = udf((c1: Int, c2: Int, c3: Int, c4: Int, c5: Int) => c1 + c2 + c3 + c4 + c5) val sum1 = session.udf.registerTemporary((c1: Int, c2: Int, c3: Int, c4: Int, c5: Int) => - c1 + c2 + c3 + c4 + c5 - ) + c1 + c2 + c3 + c4 + c5) val funcName = randomName() session.udf.registerTemporary( funcName, - (c1: Int, c2: Int, c3: Int, c4: Int, c5: Int) => c1 + c2 + c3 + c4 + c5 - ) + (c1: Int, c2: Int, c3: Int, c4: Int, c5: Int) => c1 + c2 + c3 + c4 + c5) checkAnswer( df.withColumn("res", sum(col("c1"), col("c2"), col("c3"), col("c4"), col("c5"))) .select("res"), - Seq(Row(result)) - ) + Seq(Row(result))) checkAnswer( df.withColumn("res", sum1(col("c1"), col("c2"), col("c3"), col("c4"), col("c5"))) .select("res"), - Seq(Row(result)) - ) + Seq(Row(result))) checkAnswer( df.withColumn("res", callUDF(funcName, col("c1"), col("c2"), col("c3"), col("c4"), col("c5"))) .select("res"), - Seq(Row(result)) - ) + Seq(Row(result))) } test("Test for num args : 6", JavaStoredProcExclude) { @@ -828,30 +778,25 @@ trait UDFSuite extends TestData { udf((c1: Int, c2: Int, c3: Int, c4: Int, c5: Int, c6: Int) => c1 + c2 + c3 + c4 + c5 + c6) val sum1 = session.udf.registerTemporary((c1: Int, c2: Int, c3: Int, c4: Int, c5: 
Int, c6: Int) => - c1 + c2 + c3 + c4 + c5 + c6 - ) + c1 + c2 + c3 + c4 + c5 + c6) val funcName = randomName() session.udf.registerTemporary( funcName, - (c1: Int, c2: Int, c3: Int, c4: Int, c5: Int, c6: Int) => c1 + c2 + c3 + c4 + c5 + c6 - ) + (c1: Int, c2: Int, c3: Int, c4: Int, c5: Int, c6: Int) => c1 + c2 + c3 + c4 + c5 + c6) checkAnswer( df.withColumn("res", sum(col("c1"), col("c2"), col("c3"), col("c4"), col("c5"), col("c6"))) .select("res"), - Seq(Row(result)) - ) + Seq(Row(result))) checkAnswer( df.withColumn("res", sum1(col("c1"), col("c2"), col("c3"), col("c4"), col("c5"), col("c6"))) .select("res"), - Seq(Row(result)) - ) + Seq(Row(result))) checkAnswer( df.withColumn( "res", - callUDF(funcName, col("c1"), col("c2"), col("c3"), col("c4"), col("c5"), col("c6")) - ).select("res"), - Seq(Row(result)) - ) + callUDF(funcName, col("c1"), col("c2"), col("c3"), col("c4"), col("c5"), col("c6"))) + .select("res"), + Seq(Row(result))) } test("Test for num args : 7", JavaStoredProcExclude) { @@ -859,32 +804,27 @@ trait UDFSuite extends TestData { val columns = (1 to 7).map("c" + _) val df = Seq((1, 2, 3, 4, 5, 6, 7)).toDF(columns) val sum = udf((c1: Int, c2: Int, c3: Int, c4: Int, c5: Int, c6: Int, c7: Int) => - c1 + c2 + c3 + c4 + c5 + c6 + c7 - ) + c1 + c2 + c3 + c4 + c5 + c6 + c7) val sum1 = session.udf.registerTemporary( (c1: Int, c2: Int, c3: Int, c4: Int, c5: Int, c6: Int, c7: Int) => - c1 + c2 + c3 + c4 + c5 + c6 + c7 - ) + c1 + c2 + c3 + c4 + c5 + c6 + c7) val funcName = randomName() session.udf.registerTemporary( funcName, (c1: Int, c2: Int, c3: Int, c4: Int, c5: Int, c6: Int, c7: Int) => - c1 + c2 + c3 + c4 + c5 + c6 + c7 - ) + c1 + c2 + c3 + c4 + c5 + c6 + c7) checkAnswer( df.withColumn( "res", - sum(col("c1"), col("c2"), col("c3"), col("c4"), col("c5"), col("c6"), col("c7")) - ).select("res"), - Seq(Row(result)) - ) + sum(col("c1"), col("c2"), col("c3"), col("c4"), col("c5"), col("c6"), col("c7"))) + .select("res"), + Seq(Row(result))) checkAnswer( 
df.withColumn( "res", - sum1(col("c1"), col("c2"), col("c3"), col("c4"), col("c5"), col("c6"), col("c7")) - ).select("res"), - Seq(Row(result)) - ) + sum1(col("c1"), col("c2"), col("c3"), col("c4"), col("c5"), col("c6"), col("c7"))) + .select("res"), + Seq(Row(result))) checkAnswer( df.withColumn( "res", @@ -896,11 +836,9 @@ trait UDFSuite extends TestData { col("c4"), col("c5"), col("c6"), - col("c7") - ) - ).select("res"), - Seq(Row(result)) - ) + col("c7"))) + .select("res"), + Seq(Row(result))) } test("Test for num args : 8", JavaStoredProcExclude) { @@ -908,32 +846,35 @@ trait UDFSuite extends TestData { val columns = (1 to 8).map("c" + _) val df = Seq((1, 2, 3, 4, 5, 6, 7, 8)).toDF(columns) val sum = udf((c1: Int, c2: Int, c3: Int, c4: Int, c5: Int, c6: Int, c7: Int, c8: Int) => - c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 - ) + c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8) val sum1 = session.udf.registerTemporary( (c1: Int, c2: Int, c3: Int, c4: Int, c5: Int, c6: Int, c7: Int, c8: Int) => - c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 - ) + c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8) val funcName = randomName() session.udf.registerTemporary( funcName, (c1: Int, c2: Int, c3: Int, c4: Int, c5: Int, c6: Int, c7: Int, c8: Int) => - c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 - ) + c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8) checkAnswer( df.withColumn( "res", - sum(col("c1"), col("c2"), col("c3"), col("c4"), col("c5"), col("c6"), col("c7"), col("c8")) - ).select("res"), - Seq(Row(result)) - ) + sum(col("c1"), col("c2"), col("c3"), col("c4"), col("c5"), col("c6"), col("c7"), col("c8"))) + .select("res"), + Seq(Row(result))) checkAnswer( df.withColumn( "res", - sum1(col("c1"), col("c2"), col("c3"), col("c4"), col("c5"), col("c6"), col("c7"), col("c8")) - ).select("res"), - Seq(Row(result)) - ) + sum1( + col("c1"), + col("c2"), + col("c3"), + col("c4"), + col("c5"), + col("c6"), + col("c7"), + col("c8"))) + .select("res"), + Seq(Row(result))) checkAnswer( df.withColumn( "res", @@ -946,11 +887,9 @@ 
trait UDFSuite extends TestData { col("c5"), col("c6"), col("c7"), - col("c8") - ) - ).select("res"), - Seq(Row(result)) - ) + col("c8"))) + .select("res"), + Seq(Row(result))) } test("Test for num args : 9", JavaStoredProcExclude) { @@ -959,18 +898,15 @@ trait UDFSuite extends TestData { val df = Seq((1, 2, 3, 4, 5, 6, 7, 8, 9)).toDF(columns) val sum = udf((c1: Int, c2: Int, c3: Int, c4: Int, c5: Int, c6: Int, c7: Int, c8: Int, c9: Int) => - c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 - ) + c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9) val sum1 = session.udf.registerTemporary( (c1: Int, c2: Int, c3: Int, c4: Int, c5: Int, c6: Int, c7: Int, c8: Int, c9: Int) => - c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 - ) + c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9) val funcName = randomName() session.udf.registerTemporary( funcName, (c1: Int, c2: Int, c3: Int, c4: Int, c5: Int, c6: Int, c7: Int, c8: Int, c9: Int) => - c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 - ) + c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9) checkAnswer( df.withColumn( "res", @@ -983,11 +919,9 @@ trait UDFSuite extends TestData { col("c6"), col("c7"), col("c8"), - col("c9") - ) - ).select("res"), - Seq(Row(result)) - ) + col("c9"))) + .select("res"), + Seq(Row(result))) checkAnswer( df.withColumn( "res", @@ -1000,11 +934,9 @@ trait UDFSuite extends TestData { col("c6"), col("c7"), col("c8"), - col("c9") - ) - ).select("res"), - Seq(Row(result)) - ) + col("c9"))) + .select("res"), + Seq(Row(result))) checkAnswer( df.withColumn( "res", @@ -1018,11 +950,9 @@ trait UDFSuite extends TestData { col("c6"), col("c7"), col("c8"), - col("c9") - ) - ).select("res"), - Seq(Row(result)) - ) + col("c9"))) + .select("res"), + Seq(Row(result))) } test("Test for num args : 10", JavaStoredProcExclude) { @@ -1031,18 +961,15 @@ trait UDFSuite extends TestData { val df = Seq((1, 2, 3, 4, 5, 6, 7, 8, 9, 10)).toDF(columns) val sum = udf( (c1: Int, c2: Int, c3: Int, c4: Int, c5: Int, c6: Int, c7: Int, c8: Int, c9: Int, c10: Int) => - 
c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 - ) + c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10) val sum1 = session.udf.registerTemporary( (c1: Int, c2: Int, c3: Int, c4: Int, c5: Int, c6: Int, c7: Int, c8: Int, c9: Int, c10: Int) => - c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 - ) + c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10) val funcName = randomName() session.udf.registerTemporary( funcName, (c1: Int, c2: Int, c3: Int, c4: Int, c5: Int, c6: Int, c7: Int, c8: Int, c9: Int, c10: Int) => - c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 - ) + c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10) checkAnswer( df.withColumn( "res", @@ -1056,11 +983,9 @@ trait UDFSuite extends TestData { col("c7"), col("c8"), col("c9"), - col("c10") - ) - ).select("res"), - Seq(Row(result)) - ) + col("c10"))) + .select("res"), + Seq(Row(result))) checkAnswer( df.withColumn( "res", @@ -1074,11 +999,9 @@ trait UDFSuite extends TestData { col("c7"), col("c8"), col("c9"), - col("c10") - ) - ).select("res"), - Seq(Row(result)) - ) + col("c10"))) + .select("res"), + Seq(Row(result))) checkAnswer( df.withColumn( "res", @@ -1093,11 +1016,9 @@ trait UDFSuite extends TestData { col("c7"), col("c8"), col("c9"), - col("c10") - ) - ).select("res"), - Seq(Row(result)) - ) + col("c10"))) + .select("res"), + Seq(Row(result))) } test("Test for num args : 11", JavaStoredProcExclude) { @@ -1116,9 +1037,7 @@ trait UDFSuite extends TestData { c8: Int, c9: Int, c10: Int, - c11: Int - ) => c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 - ) + c11: Int) => c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11) val sum1 = session.udf.registerTemporary( ( c1: Int, @@ -1131,9 +1050,7 @@ trait UDFSuite extends TestData { c8: Int, c9: Int, c10: Int, - c11: Int - ) => c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 - ) + c11: Int) => c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11) val funcName = randomName() session.udf.registerTemporary( funcName, @@ -1148,9 +1065,7 @@ 
trait UDFSuite extends TestData { c8: Int, c9: Int, c10: Int, - c11: Int - ) => c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 - ) + c11: Int) => c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11) checkAnswer( df.withColumn( "res", @@ -1165,11 +1080,9 @@ trait UDFSuite extends TestData { col("c8"), col("c9"), col("c10"), - col("c11") - ) - ).select("res"), - Seq(Row(result)) - ) + col("c11"))) + .select("res"), + Seq(Row(result))) checkAnswer( df.withColumn( "res", @@ -1184,11 +1097,9 @@ trait UDFSuite extends TestData { col("c8"), col("c9"), col("c10"), - col("c11") - ) - ).select("res"), - Seq(Row(result)) - ) + col("c11"))) + .select("res"), + Seq(Row(result))) checkAnswer( df.withColumn( "res", @@ -1204,11 +1115,9 @@ trait UDFSuite extends TestData { col("c8"), col("c9"), col("c10"), - col("c11") - ) - ).select("res"), - Seq(Row(result)) - ) + col("c11"))) + .select("res"), + Seq(Row(result))) } test("Test for num args : 12", JavaStoredProcExclude) { @@ -1228,9 +1137,7 @@ trait UDFSuite extends TestData { c9: Int, c10: Int, c11: Int, - c12: Int - ) => c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 - ) + c12: Int) => c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12) val sum1 = session.udf.registerTemporary( ( c1: Int, @@ -1244,9 +1151,7 @@ trait UDFSuite extends TestData { c9: Int, c10: Int, c11: Int, - c12: Int - ) => c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 - ) + c12: Int) => c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12) val funcName = randomName() session.udf.registerTemporary( funcName, @@ -1262,9 +1167,7 @@ trait UDFSuite extends TestData { c9: Int, c10: Int, c11: Int, - c12: Int - ) => c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 - ) + c12: Int) => c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12) checkAnswer( df.withColumn( "res", @@ -1280,11 +1183,9 @@ trait UDFSuite extends TestData { col("c9"), col("c10"), col("c11"), - col("c12") - ) - 
).select("res"), - Seq(Row(result)) - ) + col("c12"))) + .select("res"), + Seq(Row(result))) checkAnswer( df.withColumn( "res", @@ -1300,11 +1201,9 @@ trait UDFSuite extends TestData { col("c9"), col("c10"), col("c11"), - col("c12") - ) - ).select("res"), - Seq(Row(result)) - ) + col("c12"))) + .select("res"), + Seq(Row(result))) checkAnswer( df.withColumn( "res", @@ -1321,11 +1220,9 @@ trait UDFSuite extends TestData { col("c9"), col("c10"), col("c11"), - col("c12") - ) - ).select("res"), - Seq(Row(result)) - ) + col("c12"))) + .select("res"), + Seq(Row(result))) } test("Test for num args : 13", JavaStoredProcExclude) { @@ -1346,9 +1243,7 @@ trait UDFSuite extends TestData { c10: Int, c11: Int, c12: Int, - c13: Int - ) => c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13 - ) + c13: Int) => c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13) val sum1 = session.udf.registerTemporary( ( c1: Int, @@ -1363,9 +1258,7 @@ trait UDFSuite extends TestData { c10: Int, c11: Int, c12: Int, - c13: Int - ) => c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13 - ) + c13: Int) => c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13) val funcName = randomName() session.udf.registerTemporary( funcName, @@ -1382,9 +1275,7 @@ trait UDFSuite extends TestData { c10: Int, c11: Int, c12: Int, - c13: Int - ) => c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13 - ) + c13: Int) => c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13) checkAnswer( df.withColumn( "res", @@ -1401,11 +1292,9 @@ trait UDFSuite extends TestData { col("c10"), col("c11"), col("c12"), - col("c13") - ) - ).select("res"), - Seq(Row(result)) - ) + col("c13"))) + .select("res"), + Seq(Row(result))) checkAnswer( df.withColumn( "res", @@ -1422,11 +1311,9 @@ trait UDFSuite extends TestData { col("c10"), col("c11"), col("c12"), - col("c13") - ) - ).select("res"), - Seq(Row(result)) - ) + col("c13"))) + .select("res"), + 
Seq(Row(result))) checkAnswer( df.withColumn( "res", @@ -1444,11 +1331,9 @@ trait UDFSuite extends TestData { col("c10"), col("c11"), col("c12"), - col("c13") - ) - ).select("res"), - Seq(Row(result)) - ) + col("c13"))) + .select("res"), + Seq(Row(result))) } test("Test for num args : 14", JavaStoredProcExclude) { @@ -1470,9 +1355,7 @@ trait UDFSuite extends TestData { c11: Int, c12: Int, c13: Int, - c14: Int - ) => c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13 + c14 - ) + c14: Int) => c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13 + c14) val sum1 = session.udf.registerTemporary( ( c1: Int, @@ -1488,9 +1371,7 @@ trait UDFSuite extends TestData { c11: Int, c12: Int, c13: Int, - c14: Int - ) => c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13 + c14 - ) + c14: Int) => c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13 + c14) val funcName = randomName() session.udf.registerTemporary( funcName, @@ -1508,9 +1389,7 @@ trait UDFSuite extends TestData { c11: Int, c12: Int, c13: Int, - c14: Int - ) => c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13 + c14 - ) + c14: Int) => c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13 + c14) checkAnswer( df.withColumn( "res", @@ -1528,11 +1407,9 @@ trait UDFSuite extends TestData { col("c11"), col("c12"), col("c13"), - col("c14") - ) - ).select("res"), - Seq(Row(result)) - ) + col("c14"))) + .select("res"), + Seq(Row(result))) checkAnswer( df.withColumn( "res", @@ -1550,11 +1427,9 @@ trait UDFSuite extends TestData { col("c11"), col("c12"), col("c13"), - col("c14") - ) - ).select("res"), - Seq(Row(result)) - ) + col("c14"))) + .select("res"), + Seq(Row(result))) checkAnswer( df.withColumn( "res", @@ -1573,11 +1448,9 @@ trait UDFSuite extends TestData { col("c11"), col("c12"), col("c13"), - col("c14") - ) - ).select("res"), - Seq(Row(result)) - ) + col("c14"))) + .select("res"), + Seq(Row(result))) } test("Test for num args : 
15", JavaStoredProcExclude) { @@ -1600,9 +1473,8 @@ trait UDFSuite extends TestData { c12: Int, c13: Int, c14: Int, - c15: Int - ) => c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13 + c14 + c15 - ) + c15: Int) => + c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13 + c14 + c15) val sum1 = session.udf.registerTemporary( ( c1: Int, @@ -1619,9 +1491,8 @@ trait UDFSuite extends TestData { c12: Int, c13: Int, c14: Int, - c15: Int - ) => c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13 + c14 + c15 - ) + c15: Int) => + c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13 + c14 + c15) val funcName = randomName() session.udf.registerTemporary( funcName, @@ -1640,9 +1511,8 @@ trait UDFSuite extends TestData { c12: Int, c13: Int, c14: Int, - c15: Int - ) => c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13 + c14 + c15 - ) + c15: Int) => + c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13 + c14 + c15) checkAnswer( df.withColumn( "res", @@ -1661,11 +1531,9 @@ trait UDFSuite extends TestData { col("c12"), col("c13"), col("c14"), - col("c15") - ) - ).select("res"), - Seq(Row(result)) - ) + col("c15"))) + .select("res"), + Seq(Row(result))) checkAnswer( df.withColumn( "res", @@ -1684,11 +1552,9 @@ trait UDFSuite extends TestData { col("c12"), col("c13"), col("c14"), - col("c15") - ) - ).select("res"), - Seq(Row(result)) - ) + col("c15"))) + .select("res"), + Seq(Row(result))) checkAnswer( df.withColumn( "res", @@ -1708,11 +1574,9 @@ trait UDFSuite extends TestData { col("c12"), col("c13"), col("c14"), - col("c15") - ) - ).select("res"), - Seq(Row(result)) - ) + col("c15"))) + .select("res"), + Seq(Row(result))) } test("Test for num args : 16", JavaStoredProcExclude) { @@ -1736,9 +1600,8 @@ trait UDFSuite extends TestData { c13: Int, c14: Int, c15: Int, - c16: Int - ) => c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13 + c14 + c15 + c16 - ) + c16: Int) => + c1 + c2 
+ c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13 + c14 + c15 + c16) val sum1 = session.udf.registerTemporary( ( c1: Int, @@ -1756,9 +1619,8 @@ trait UDFSuite extends TestData { c13: Int, c14: Int, c15: Int, - c16: Int - ) => c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13 + c14 + c15 + c16 - ) + c16: Int) => + c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13 + c14 + c15 + c16) val funcName = randomName() session.udf.registerTemporary( funcName, @@ -1778,9 +1640,8 @@ trait UDFSuite extends TestData { c13: Int, c14: Int, c15: Int, - c16: Int - ) => c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13 + c14 + c15 + c16 - ) + c16: Int) => + c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13 + c14 + c15 + c16) checkAnswer( df.withColumn( "res", @@ -1800,11 +1661,9 @@ trait UDFSuite extends TestData { col("c13"), col("c14"), col("c15"), - col("c16") - ) - ).select("res"), - Seq(Row(result)) - ) + col("c16"))) + .select("res"), + Seq(Row(result))) checkAnswer( df.withColumn( "res", @@ -1824,11 +1683,9 @@ trait UDFSuite extends TestData { col("c13"), col("c14"), col("c15"), - col("c16") - ) - ).select("res"), - Seq(Row(result)) - ) + col("c16"))) + .select("res"), + Seq(Row(result))) checkAnswer( df.withColumn( "res", @@ -1849,11 +1706,9 @@ trait UDFSuite extends TestData { col("c13"), col("c14"), col("c15"), - col("c16") - ) - ).select("res"), - Seq(Row(result)) - ) + col("c16"))) + .select("res"), + Seq(Row(result))) } test("Test for num args : 17", JavaStoredProcExclude) { @@ -1878,10 +1733,8 @@ trait UDFSuite extends TestData { c14: Int, c15: Int, c16: Int, - c17: Int - ) => - c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13 + c14 + c15 + c16 + c17 - ) + c17: Int) => + c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13 + c14 + c15 + c16 + c17) val sum1 = session.udf.registerTemporary( ( c1: Int, @@ -1900,10 +1753,8 @@ trait UDFSuite extends TestData { c14: 
Int, c15: Int, c16: Int, - c17: Int - ) => - c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13 + c14 + c15 + c16 + c17 - ) + c17: Int) => + c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13 + c14 + c15 + c16 + c17) val funcName = randomName() session.udf.registerTemporary( funcName, @@ -1924,10 +1775,8 @@ trait UDFSuite extends TestData { c14: Int, c15: Int, c16: Int, - c17: Int - ) => - c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13 + c14 + c15 + c16 + c17 - ) + c17: Int) => + c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13 + c14 + c15 + c16 + c17) checkAnswer( df.withColumn( "res", @@ -1948,11 +1797,9 @@ trait UDFSuite extends TestData { col("c14"), col("c15"), col("c16"), - col("c17") - ) - ).select("res"), - Seq(Row(result)) - ) + col("c17"))) + .select("res"), + Seq(Row(result))) checkAnswer( df.withColumn( "res", @@ -1973,11 +1820,9 @@ trait UDFSuite extends TestData { col("c14"), col("c15"), col("c16"), - col("c17") - ) - ).select("res"), - Seq(Row(result)) - ) + col("c17"))) + .select("res"), + Seq(Row(result))) checkAnswer( df.withColumn( "res", @@ -1999,11 +1844,9 @@ trait UDFSuite extends TestData { col("c14"), col("c15"), col("c16"), - col("c17") - ) - ).select("res"), - Seq(Row(result)) - ) + col("c17"))) + .select("res"), + Seq(Row(result))) } test("Test for num args : 18", JavaStoredProcExclude) { @@ -2029,10 +1872,8 @@ trait UDFSuite extends TestData { c15: Int, c16: Int, c17: Int, - c18: Int - ) => - c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13 + c14 + c15 + c16 + c17 + c18 - ) + c18: Int) => + c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13 + c14 + c15 + c16 + c17 + c18) val sum1 = session.udf.registerTemporary( ( c1: Int, @@ -2052,10 +1893,8 @@ trait UDFSuite extends TestData { c15: Int, c16: Int, c17: Int, - c18: Int - ) => - c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13 + c14 + c15 + c16 + c17 + c18 - ) + c18: 
Int) => + c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13 + c14 + c15 + c16 + c17 + c18) val funcName = randomName() session.udf.registerTemporary( funcName, @@ -2077,10 +1916,8 @@ trait UDFSuite extends TestData { c15: Int, c16: Int, c17: Int, - c18: Int - ) => - c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13 + c14 + c15 + c16 + c17 + c18 - ) + c18: Int) => + c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13 + c14 + c15 + c16 + c17 + c18) checkAnswer( df.withColumn( "res", @@ -2102,11 +1939,9 @@ trait UDFSuite extends TestData { col("c15"), col("c16"), col("c17"), - col("c18") - ) - ).select("res"), - Seq(Row(result)) - ) + col("c18"))) + .select("res"), + Seq(Row(result))) checkAnswer( df.withColumn( "res", @@ -2128,11 +1963,9 @@ trait UDFSuite extends TestData { col("c15"), col("c16"), col("c17"), - col("c18") - ) - ).select("res"), - Seq(Row(result)) - ) + col("c18"))) + .select("res"), + Seq(Row(result))) checkAnswer( df.withColumn( "res", @@ -2155,11 +1988,9 @@ trait UDFSuite extends TestData { col("c15"), col("c16"), col("c17"), - col("c18") - ) - ).select("res"), - Seq(Row(result)) - ) + col("c18"))) + .select("res"), + Seq(Row(result))) } test("Test for num args : 19", JavaStoredProcExclude) { @@ -2187,10 +2018,8 @@ trait UDFSuite extends TestData { c16: Int, c17: Int, c18: Int, - c19: Int - ) => - c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13 + c14 + c15 + c16 + c17 + c18 + c19 - ) + c19: Int) => + c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13 + c14 + c15 + c16 + c17 + c18 + c19) val sum1 = session.udf.registerTemporary( ( c1: Int, @@ -2211,10 +2040,8 @@ trait UDFSuite extends TestData { c16: Int, c17: Int, c18: Int, - c19: Int - ) => - c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13 + c14 + c15 + c16 + c17 + c18 + c19 - ) + c19: Int) => + c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13 + c14 + c15 + c16 + c17 + c18 + 
c19) val funcName = randomName() session.udf.registerTemporary( funcName, @@ -2237,10 +2064,8 @@ trait UDFSuite extends TestData { c16: Int, c17: Int, c18: Int, - c19: Int - ) => - c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13 + c14 + c15 + c16 + c17 + c18 + c19 - ) + c19: Int) => + c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13 + c14 + c15 + c16 + c17 + c18 + c19) checkAnswer( df.withColumn( "res", @@ -2263,11 +2088,9 @@ trait UDFSuite extends TestData { col("c16"), col("c17"), col("c18"), - col("c19") - ) - ).select("res"), - Seq(Row(result)) - ) + col("c19"))) + .select("res"), + Seq(Row(result))) checkAnswer( df.withColumn( "res", @@ -2290,11 +2113,9 @@ trait UDFSuite extends TestData { col("c16"), col("c17"), col("c18"), - col("c19") - ) - ).select("res"), - Seq(Row(result)) - ) + col("c19"))) + .select("res"), + Seq(Row(result))) checkAnswer( df.withColumn( "res", @@ -2318,11 +2139,9 @@ trait UDFSuite extends TestData { col("c16"), col("c17"), col("c18"), - col("c19") - ) - ).select("res"), - Seq(Row(result)) - ) + col("c19"))) + .select("res"), + Seq(Row(result))) } test("Test for num args : 20", JavaStoredProcExclude) { @@ -2351,10 +2170,8 @@ trait UDFSuite extends TestData { c17: Int, c18: Int, c19: Int, - c20: Int - ) => - c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13 + c14 + c15 + c16 + c17 + c18 + c19 + c20 - ) + c20: Int) => + c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13 + c14 + c15 + c16 + c17 + c18 + c19 + c20) val sum1 = session.udf.registerTemporary( ( c1: Int, @@ -2376,10 +2193,8 @@ trait UDFSuite extends TestData { c17: Int, c18: Int, c19: Int, - c20: Int - ) => - c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13 + c14 + c15 + c16 + c17 + c18 + c19 + c20 - ) + c20: Int) => + c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13 + c14 + c15 + c16 + c17 + c18 + c19 + c20) val funcName = randomName() session.udf.registerTemporary( 
funcName, @@ -2403,10 +2218,8 @@ trait UDFSuite extends TestData { c17: Int, c18: Int, c19: Int, - c20: Int - ) => - c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13 + c14 + c15 + c16 + c17 + c18 + c19 + c20 - ) + c20: Int) => + c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13 + c14 + c15 + c16 + c17 + c18 + c19 + c20) checkAnswer( df.withColumn( "res", @@ -2430,11 +2243,9 @@ trait UDFSuite extends TestData { col("c17"), col("c18"), col("c19"), - col("c20") - ) - ).select("res"), - Seq(Row(result)) - ) + col("c20"))) + .select("res"), + Seq(Row(result))) checkAnswer( df.withColumn( "res", @@ -2458,11 +2269,9 @@ trait UDFSuite extends TestData { col("c17"), col("c18"), col("c19"), - col("c20") - ) - ).select("res"), - Seq(Row(result)) - ) + col("c20"))) + .select("res"), + Seq(Row(result))) checkAnswer( df.withColumn( "res", @@ -2487,11 +2296,9 @@ trait UDFSuite extends TestData { col("c17"), col("c18"), col("c19"), - col("c20") - ) - ).select("res"), - Seq(Row(result)) - ) + col("c20"))) + .select("res"), + Seq(Row(result))) } test("Test for num args : 21", JavaStoredProcExclude) { @@ -2521,10 +2328,8 @@ trait UDFSuite extends TestData { c18: Int, c19: Int, c20: Int, - c21: Int - ) => - c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13 + c14 + c15 + c16 + c17 + c18 + c19 + c20 + c21 - ) + c21: Int) => + c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13 + c14 + c15 + c16 + c17 + c18 + c19 + c20 + c21) val sum1 = session.udf.registerTemporary( ( c1: Int, @@ -2547,10 +2352,8 @@ trait UDFSuite extends TestData { c18: Int, c19: Int, c20: Int, - c21: Int - ) => - c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13 + c14 + c15 + c16 + c17 + c18 + c19 + c20 + c21 - ) + c21: Int) => + c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13 + c14 + c15 + c16 + c17 + c18 + c19 + c20 + c21) val funcName = randomName() session.udf.registerTemporary( funcName, @@ -2575,10 +2378,8 @@ 
trait UDFSuite extends TestData { c18: Int, c19: Int, c20: Int, - c21: Int - ) => - c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13 + c14 + c15 + c16 + c17 + c18 + c19 + c20 + c21 - ) + c21: Int) => + c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13 + c14 + c15 + c16 + c17 + c18 + c19 + c20 + c21) checkAnswer( df.withColumn( "res", @@ -2603,11 +2404,9 @@ trait UDFSuite extends TestData { col("c18"), col("c19"), col("c20"), - col("c21") - ) - ).select("res"), - Seq(Row(result)) - ) + col("c21"))) + .select("res"), + Seq(Row(result))) checkAnswer( df.withColumn( "res", @@ -2632,11 +2431,9 @@ trait UDFSuite extends TestData { col("c18"), col("c19"), col("c20"), - col("c21") - ) - ).select("res"), - Seq(Row(result)) - ) + col("c21"))) + .select("res"), + Seq(Row(result))) checkAnswer( df.withColumn( "res", @@ -2662,11 +2459,9 @@ trait UDFSuite extends TestData { col("c18"), col("c19"), col("c20"), - col("c21") - ) - ).select("res"), - Seq(Row(result)) - ) + col("c21"))) + .select("res"), + Seq(Row(result))) } test("Test for num args : 22", JavaStoredProcExclude) { @@ -2697,10 +2492,8 @@ trait UDFSuite extends TestData { c19: Int, c20: Int, c21: Int, - c22: Int - ) => - c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13 + c14 + c15 + c16 + c17 + c18 + c19 + c20 + c21 + c22 - ) + c22: Int) => + c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13 + c14 + c15 + c16 + c17 + c18 + c19 + c20 + c21 + c22) val sum1 = session.udf.registerTemporary( ( c1: Int, @@ -2724,10 +2517,8 @@ trait UDFSuite extends TestData { c19: Int, c20: Int, c21: Int, - c22: Int - ) => - c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13 + c14 + c15 + c16 + c17 + c18 + c19 + c20 + c21 + c22 - ) + c22: Int) => + c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13 + c14 + c15 + c16 + c17 + c18 + c19 + c20 + c21 + c22) val funcName = randomName() session.udf.registerTemporary( funcName, @@ -2753,10 +2544,8 
@@ trait UDFSuite extends TestData { c19: Int, c20: Int, c21: Int, - c22: Int - ) => - c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13 + c14 + c15 + c16 + c17 + c18 + c19 + c20 + c21 + c22 - ) + c22: Int) => + c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11 + c12 + c13 + c14 + c15 + c16 + c17 + c18 + c19 + c20 + c21 + c22) checkAnswer( df.withColumn( "res", @@ -2782,11 +2571,9 @@ trait UDFSuite extends TestData { col("c19"), col("c20"), col("c21"), - col("c22") - ) - ).select("res"), - Seq(Row(result)) - ) + col("c22"))) + .select("res"), + Seq(Row(result))) checkAnswer( df.withColumn( "res", @@ -2812,11 +2599,9 @@ trait UDFSuite extends TestData { col("c19"), col("c20"), col("c21"), - col("c22") - ) - ).select("res"), - Seq(Row(result)) - ) + col("c22"))) + .select("res"), + Seq(Row(result))) checkAnswer( df.withColumn( "res", @@ -2843,11 +2628,9 @@ trait UDFSuite extends TestData { col("c19"), col("c20"), col("c21"), - col("c22") - ) - ).select("res"), - Seq(Row(result)) - ) + col("c22"))) + .select("res"), + Seq(Row(result))) } // system$cancel_all_queries not allowed from owner mode procs @@ -2977,8 +2760,7 @@ trait UDFSuite extends TestData { ||0.5399685289472378 | ||1.0 | |------------------------------ - |""".stripMargin - ) + |""".stripMargin) } test("register temp UDF doesn't commit open transaction") { diff --git a/src/test/scala/com/snowflake/snowpark_test/UDTFSuite.scala b/src/test/scala/com/snowflake/snowpark_test/UDTFSuite.scala index 95d6e74d..2c616723 100644 --- a/src/test/scala/com/snowflake/snowpark_test/UDTFSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/UDTFSuite.scala @@ -68,8 +68,7 @@ class UDTFSuite extends TestData { """root | |--WORD: String (nullable = true) | |--COUNT: Long (nullable = true) - |""".stripMargin - ) + |""".stripMargin) checkAnswer(df1, Seq(Row("w1", 1), Row("w2", 2), Row("w3", 3))) // Call the UDTF with funcName and named parameters, result should be the same @@ -81,14 +80,12 @@ class 
UDTFSuite extends TestData { """root | |--COUNT: Long (nullable = true) | |--WORD: String (nullable = true) - |""".stripMargin - ) + |""".stripMargin) checkAnswer(df2, Seq(Row(1, "w1"), Row(2, "w2"), Row(3, "w3"))) // Use UDTF with table join val df3 = session.sql( - s"select * from $wordCountTableName, table($funcName(c1) over (partition by c2))" - ) + s"select * from $wordCountTableName, table($funcName(c1) over (partition by c2))") assert( getSchemaString(df3.schema) == """root @@ -96,16 +93,13 @@ class UDTFSuite extends TestData { | |--C2: String (nullable = true) | |--WORD: String (nullable = true) | |--COUNT: Long (nullable = true) - |""".stripMargin - ) + |""".stripMargin) checkAnswer( df3, Seq( Row("w1 w2", "g1", "w2", 1), Row("w1 w2", "g1", "w1", 1), - Row("w1 w1 w1", "g2", "w1", 3) - ) - ) + Row("w1 w1 w1", "g2", "w1", 3))) } finally { runQuery(s"drop function if exists $funcName(STRING)", session) } @@ -147,31 +141,26 @@ class UDTFSuite extends TestData { """root | |--WORD: String (nullable = true) | |--COUNT: Long (nullable = true) - |""".stripMargin - ) + |""".stripMargin) checkAnswer( df1, - Seq(Row("w3", 6), Row("w2", 4), Row("w1", 2), Row("w3", 6), Row("w2", 4), Row("w1", 2)) - ) + Seq(Row("w3", 6), Row("w2", 4), Row("w1", 2), Row("w3", 6), Row("w2", 4), Row("w1", 2))) // Call the UDTF with funcName and named parameters, result should be the same val df2 = session .tableFunction( TableFunction(funcName), - Map("arg1" -> lit("w1 w2 w2 w3 w3 w3"), "arg2" -> lit("w1 w2 w2 w3 w3 w3")) - ) + Map("arg1" -> lit("w1 w2 w2 w3 w3 w3"), "arg2" -> lit("w1 w2 w2 w3 w3 w3"))) .select("count", "word") assert( getSchemaString(df2.schema) == """root | |--COUNT: Long (nullable = true) | |--WORD: String (nullable = true) - |""".stripMargin - ) + |""".stripMargin) checkAnswer( df2, - Seq(Row(6, "w3"), Row(4, "w2"), Row(2, "w1"), Row(6, "w3"), Row(4, "w2"), Row(2, "w1")) - ) + Seq(Row(6, "w3"), Row(4, "w2"), Row(2, "w1"), Row(6, "w3"), Row(4, "w2"), Row(2, "w1"))) // 
scalastyle:off // Use UDTF with table join @@ -187,8 +176,7 @@ class UDTFSuite extends TestData { // scalastyle:on val df3 = session.sql( s"select * from $wordCountTableName, " + - s"table($funcName(c1, c2) over (partition by 1))" - ) + s"table($funcName(c1, c2) over (partition by 1))") assert( getSchemaString(df3.schema) == """root @@ -196,8 +184,7 @@ class UDTFSuite extends TestData { | |--C2: String (nullable = true) | |--WORD: String (nullable = true) | |--COUNT: Long (nullable = true) - |""".stripMargin - ) + |""".stripMargin) checkAnswer( df3, Seq( @@ -209,14 +196,11 @@ class UDTFSuite extends TestData { Row(null, null, "g2", 1), Row(null, null, "w2", 1), Row(null, null, "g1", 1), - Row(null, null, "w1", 4) - ) - ) + Row(null, null, "w1", 4))) // Use UDTF with table function + over partition val df4 = session.sql( - s"select * from $wordCountTableName, table($funcName(c1, c2) over (partition by c2))" - ) + s"select * from $wordCountTableName, table($funcName(c1, c2) over (partition by c2))") checkAnswer( df4, Seq( @@ -229,9 +213,7 @@ class UDTFSuite extends TestData { Row("w1 w1 w1", "g2", "g2", 1), Row("w1 w1 w1", "g2", "w1", 3), Row(null, "g2", "g2", 1), - Row(null, "g2", "w1", 3) - ) - ) + Row(null, "g2", "w1", 3))) } finally { runQuery(s"drop function if exists $funcName(VARCHAR,VARCHAR)", session) } @@ -257,8 +239,7 @@ class UDTFSuite extends TestData { getSchemaString(df1.schema) == """root | |--C1: Long (nullable = true) - |""".stripMargin - ) + |""".stripMargin) checkAnswer(df1, Seq(Row(10), Row(11), Row(12), Row(13), Row(14))) val df2 = session.tableFunction(TableFunction(funcName), lit(20), lit(5)) @@ -266,8 +247,7 @@ class UDTFSuite extends TestData { getSchemaString(df2.schema) == """root | |--C1: Long (nullable = true) - |""".stripMargin - ) + |""".stripMargin) checkAnswer(df2, Seq(Row(20), Row(21), Row(22), Row(23), Row(24))) val df3 = session @@ -276,8 +256,7 @@ class UDTFSuite extends TestData { getSchemaString(df3.schema) == """root | |--C1: 
Long (nullable = true) - |""".stripMargin - ) + |""".stripMargin) checkAnswer(df3, Seq(Row(30), Row(31), Row(32), Row(33), Row(34))) } finally { runQuery(s"drop function if exists $funcName(NUMBER, NUMBER)", session) @@ -334,8 +313,7 @@ class UDTFSuite extends TestData { StructType( StructField("word", StringType), StructField("count", IntegerType), - StructField("size", IntegerType) - ) + StructField("size", IntegerType)) } val largeUdTF = new LargeUDTF() @@ -351,8 +329,7 @@ class UDTFSuite extends TestData { | |--WORD: String (nullable = true) | |--COUNT: Long (nullable = true) | |--SIZE: Long (nullable = true) - |""".stripMargin - ) + |""".stripMargin) checkAnswer(df1, Seq(Row("w1", 1, dataLength), Row("w2", 1, dataLength))) } finally { runQuery(s"drop function if exists $funcName(STRING)", session) @@ -394,8 +371,7 @@ class UDTFSuite extends TestData { getSchemaString(df1.schema) == """root | |--C1: Long (nullable = true) - |""".stripMargin - ) + |""".stripMargin) checkAnswer(df1, Seq(Row(10), Row(11), Row(20), Row(21), Row(22), Row(23))) val df2 = session.tableFunction(TableFunction(funcName), sourceDF("b"), sourceDF("c")) @@ -403,8 +379,7 @@ class UDTFSuite extends TestData { getSchemaString(df2.schema) == """root | |--C1: Long (nullable = true) - |""".stripMargin - ) + |""".stripMargin) checkAnswer(df2, Seq(Row(100), Row(101), Row(200), Row(201), Row(202), Row(203))) // Check table function with df column arguments as Map @@ -415,8 +390,7 @@ class UDTFSuite extends TestData { getSchemaString(df3.schema) == """root | |--C1: Long (nullable = true) - |""".stripMargin - ) + |""".stripMargin) checkAnswer(df3, Seq(Row(30), Row(31), Row(32), Row(33), Row(34))) // Check table function with nested functions on df column @@ -425,8 +399,7 @@ class UDTFSuite extends TestData { getSchemaString(df4.schema) == """root | |--C1: Long (nullable = true) - |""".stripMargin - ) + |""".stripMargin) checkAnswer(df4, Seq(Row(10), Row(11), Row(20), Row(21))) // Check result df 
column filtering with duplicate column names @@ -440,8 +413,7 @@ class UDTFSuite extends TestData { getSchemaString(df.schema) == """root | |--C1: Long (nullable = true) - |""".stripMargin - ) + |""".stripMargin) checkAnswer(df, Seq(Row(30), Row(31), Row(32), Row(33), Row(34))) } } finally { @@ -666,8 +638,7 @@ class UDTFSuite extends TestData { lit(5), lit(6), lit(7), - lit(8) - ) + lit(8)) val result = (1 to 8).sum checkAnswer(df1, Seq(Row(result), Row(result))) } @@ -683,8 +654,7 @@ class UDTFSuite extends TestData { a6: Int, a7: Int, a8: Int, - a9: Int - ): Iterable[Row] = { + a9: Int): Iterable[Row] = { val sum = Seq(a1, a2, a3, a4, a5, a6, a7, a8, a9).sum Seq(Row(sum), Row(sum)) } @@ -703,8 +673,7 @@ class UDTFSuite extends TestData { lit(6), lit(7), lit(8), - lit(9) - ) + lit(9)) val result = (1 to 9).sum checkAnswer(df1, Seq(Row(result), Row(result))) } @@ -721,8 +690,7 @@ class UDTFSuite extends TestData { a7: Int, a8: Int, a9: Int, - a10: Int - ): Iterable[Row] = { + a10: Int): Iterable[Row] = { val sum = Seq(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10).sum Seq(Row(sum), Row(sum)) } @@ -742,8 +710,7 @@ class UDTFSuite extends TestData { lit(7), lit(8), lit(9), - lit(10) - ) + lit(10)) val result = (1 to 10).sum checkAnswer(df1, Seq(Row(result), Row(result))) } @@ -762,8 +729,7 @@ class UDTFSuite extends TestData { a8: Int, a9: Int, a10: Int, - a11: Int - ): Iterable[Row] = { + a11: Int): Iterable[Row] = { val sum = Seq(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11).sum Seq(Row(sum), Row(sum)) } @@ -785,8 +751,7 @@ class UDTFSuite extends TestData { lit(8), lit(9), lit(10), - lit(11) - ) + lit(11)) val result = (1 to 11).sum checkAnswer(df1, Seq(Row(result), Row(result))) } @@ -806,8 +771,7 @@ class UDTFSuite extends TestData { a9: Int, a10: Int, a11: Int, - a12: Int - ): Iterable[Row] = { + a12: Int): Iterable[Row] = { val sum = Seq(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12).sum Seq(Row(sum), Row(sum)) } @@ -830,8 +794,7 @@ class UDTFSuite extends 
TestData { lit(9), lit(10), lit(11), - lit(12) - ) + lit(12)) val result = (1 to 12).sum checkAnswer(df1, Seq(Row(result), Row(result))) } @@ -852,8 +815,7 @@ class UDTFSuite extends TestData { a10: Int, a11: Int, a12: Int, - a13: Int - ): Iterable[Row] = { + a13: Int): Iterable[Row] = { val sum = Seq(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13).sum Seq(Row(sum), Row(sum)) } @@ -877,8 +839,7 @@ class UDTFSuite extends TestData { lit(10), lit(11), lit(12), - lit(13) - ) + lit(13)) val result = (1 to 13).sum checkAnswer(df1, Seq(Row(result), Row(result))) } @@ -901,8 +862,7 @@ class UDTFSuite extends TestData { a11: Int, a12: Int, a13: Int, - a14: Int - ): Iterable[Row] = { + a14: Int): Iterable[Row] = { val sum = Seq(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14).sum Seq(Row(sum), Row(sum)) } @@ -927,8 +887,7 @@ class UDTFSuite extends TestData { lit(11), lit(12), lit(13), - lit(14) - ) + lit(14)) val result = (1 to 14).sum checkAnswer(df1, Seq(Row(result), Row(result))) } @@ -951,8 +910,7 @@ class UDTFSuite extends TestData { a12: Int, a13: Int, a14: Int, - a15: Int - ): Iterable[Row] = { + a15: Int): Iterable[Row] = { val sum = Seq(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15).sum Seq(Row(sum), Row(sum)) } @@ -977,8 +935,7 @@ class UDTFSuite extends TestData { lit(12), lit(13), lit(14), - lit(15) - ) + lit(15)) val result = (1 to 15).sum checkAnswer(df1, Seq(Row(result), Row(result))) } @@ -1001,8 +958,7 @@ class UDTFSuite extends TestData { Int, Int, Int, - Int - ] { + Int] { override def process( a1: Int, a2: Int, @@ -1019,8 +975,7 @@ class UDTFSuite extends TestData { a13: Int, a14: Int, a15: Int, - a16: Int - ): Iterable[Row] = { + a16: Int): Iterable[Row] = { val sum = Seq(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16).sum Seq(Row(sum), Row(sum)) } @@ -1046,8 +1001,7 @@ class UDTFSuite extends TestData { lit(13), lit(14), lit(15), - lit(16) - ) + lit(16)) val result = (1 to 16).sum 
checkAnswer(df1, Seq(Row(result), Row(result))) } @@ -1071,8 +1025,7 @@ class UDTFSuite extends TestData { Int, Int, Int, - Int - ] { + Int] { override def process( a1: Int, a2: Int, @@ -1090,8 +1043,7 @@ class UDTFSuite extends TestData { a14: Int, a15: Int, a16: Int, - a17: Int - ): Iterable[Row] = { + a17: Int): Iterable[Row] = { val sum = Seq(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17).sum Seq(Row(sum), Row(sum)) @@ -1119,8 +1071,7 @@ class UDTFSuite extends TestData { lit(14), lit(15), lit(16), - lit(17) - ) + lit(17)) val result = (1 to 17).sum checkAnswer(df1, Seq(Row(result), Row(result))) } @@ -1145,8 +1096,7 @@ class UDTFSuite extends TestData { Int, Int, Int, - Int - ] { + Int] { override def process( a1: Int, a2: Int, @@ -1165,8 +1115,7 @@ class UDTFSuite extends TestData { a15: Int, a16: Int, a17: Int, - a18: Int - ): Iterable[Row] = { + a18: Int): Iterable[Row] = { val sum = Seq(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18).sum Seq(Row(sum), Row(sum)) @@ -1195,8 +1144,7 @@ class UDTFSuite extends TestData { lit(15), lit(16), lit(17), - lit(18) - ) + lit(18)) val result = (1 to 18).sum checkAnswer(df1, Seq(Row(result), Row(result))) } @@ -1222,8 +1170,7 @@ class UDTFSuite extends TestData { Int, Int, Int, - Int - ] { + Int] { override def process( a1: Int, a2: Int, @@ -1243,8 +1190,7 @@ class UDTFSuite extends TestData { a16: Int, a17: Int, a18: Int, - a19: Int - ): Iterable[Row] = { + a19: Int): Iterable[Row] = { val sum = Seq( a1, a2, @@ -1264,8 +1210,7 @@ class UDTFSuite extends TestData { a16, a17, a18, - a19 - ).sum + a19).sum Seq(Row(sum), Row(sum)) } override def endPartition(): Iterable[Row] = Seq.empty @@ -1293,8 +1238,7 @@ class UDTFSuite extends TestData { lit(16), lit(17), lit(18), - lit(19) - ) + lit(19)) val result = (1 to 19).sum checkAnswer(df1, Seq(Row(result), Row(result))) } @@ -1321,8 +1265,7 @@ class UDTFSuite extends TestData { Int, Int, Int, - Int - ] { + Int] { 
override def process( a1: Int, a2: Int, @@ -1343,8 +1286,7 @@ class UDTFSuite extends TestData { a17: Int, a18: Int, a19: Int, - a20: Int - ): Iterable[Row] = { + a20: Int): Iterable[Row] = { val sum = Seq( a1, a2, @@ -1365,8 +1307,7 @@ class UDTFSuite extends TestData { a17, a18, a19, - a20 - ).sum + a20).sum Seq(Row(sum), Row(sum)) } override def endPartition(): Iterable[Row] = Seq.empty @@ -1395,8 +1336,7 @@ class UDTFSuite extends TestData { lit(17), lit(18), lit(19), - lit(20) - ) + lit(20)) val result = (1 to 20).sum checkAnswer(df1, Seq(Row(result), Row(result))) } @@ -1424,8 +1364,7 @@ class UDTFSuite extends TestData { Int, Int, Int, - Int - ] { + Int] { override def process( a1: Int, a2: Int, @@ -1447,8 +1386,7 @@ class UDTFSuite extends TestData { a18: Int, a19: Int, a20: Int, - a21: Int - ): Iterable[Row] = { + a21: Int): Iterable[Row] = { val sum = Seq( a1, a2, @@ -1470,8 +1408,7 @@ class UDTFSuite extends TestData { a18, a19, a20, - a21 - ).sum + a21).sum Seq(Row(sum), Row(sum)) } override def endPartition(): Iterable[Row] = Seq.empty @@ -1501,8 +1438,7 @@ class UDTFSuite extends TestData { lit(18), lit(19), lit(20), - lit(21) - ) + lit(21)) val result = (1 to 21).sum checkAnswer(df1, Seq(Row(result), Row(result))) } @@ -1531,8 +1467,7 @@ class UDTFSuite extends TestData { Int, Int, Int, - Int - ] { + Int] { override def process( a1: Int, a2: Int, @@ -1555,8 +1490,7 @@ class UDTFSuite extends TestData { a19: Int, a20: Int, a21: Int, - a22: Int - ): Iterable[Row] = { + a22: Int): Iterable[Row] = { val sum = Seq( a1, a2, @@ -1579,8 +1513,7 @@ class UDTFSuite extends TestData { a19, a20, a21, - a22 - ).sum + a22).sum Seq(Row(sum), Row(sum)) } override def endPartition(): Iterable[Row] = Seq.empty @@ -1611,33 +1544,34 @@ class UDTFSuite extends TestData { lit(19), lit(20), lit(21), - lit(22) - ) + lit(22)) val result = (1 to 22).sum checkAnswer(df1, Seq(Row(result), Row(result))) } test("test input Type: Option[_]") { class UDTFInputOptionTypes - extends 
UDTF6[Option[Short], Option[Int], Option[Long], Option[Float], Option[ - Double - ], Option[Boolean]] { + extends UDTF6[ + Option[Short], + Option[Int], + Option[Long], + Option[Float], + Option[Double], + Option[Boolean]] { override def process( si2: Option[Short], i2: Option[Int], li2: Option[Long], f2: Option[Float], d2: Option[Double], - b2: Option[Boolean] - ): Iterable[Row] = { + b2: Option[Boolean]): Iterable[Row] = { val row = Row( si2.map(_.toString).orNull, i2.map(_.toString).orNull, li2.map(_.toString).orNull, f2.map(_.toString).orNull, d2.map(_.toString).orNull, - b2.map(_.toString).orNull - ) + b2.map(_.toString).orNull) Seq(row) } override def endPartition(): Iterable[Row] = Seq.empty @@ -1648,16 +1582,14 @@ class UDTFSuite extends TestData { StructField("bi2_str", StringType), StructField("f2_str", StringType), StructField("d2_str", StringType), - StructField("b2_str", StringType) - ) + StructField("b2_str", StringType)) } createTable(tableName, "si2 smallint, i2 int, bi2 bigint, f2 float, d2 double, b2 boolean") runQuery( s"insert into $tableName values (1, 2, 3, 4.4, 8.8, true)," + s" (null, null, null, null, null, null)", - session - ) + session) val tableFunction = session.udtf.registerTemporary(new UDTFInputOptionTypes) val df1 = session .table(tableName) @@ -1677,15 +1609,12 @@ class UDTFSuite extends TestData { | |--F2_STR: String (nullable = true) | |--D2_STR: String (nullable = true) | |--B2_STR: String (nullable = true) - |""".stripMargin - ) + |""".stripMargin) checkAnswer( df1, Seq( Row(1, 2, 3, 4.4, 8.8, true, "1", "2", "3", "4.4", "8.8", "true"), - Row(null, null, null, null, null, null, null, null, null, null, null, null) - ) - ) + Row(null, null, null, null, null, null, null, null, null, null, null, null))) } test("test input Type: basic types") { @@ -1700,8 +1629,7 @@ class UDTFSuite extends TestData { java.math.BigDecimal, String, java.lang.String, - Array[Byte] - ] { + Array[Byte]] { override def process( si1: Short, i1: Int, @@ 
-1712,8 +1640,7 @@ class UDTFSuite extends TestData { decimal: java.math.BigDecimal, str: String, str2: java.lang.String, - bytes: Array[Byte] - ): Iterable[Row] = { + bytes: Array[Byte]): Iterable[Row] = { val row = Row( si1.toString, i1.toString, @@ -1724,8 +1651,7 @@ class UDTFSuite extends TestData { decimal.toString, str, str2, - bytes.map { _.toChar }.mkString - ) + bytes.map { _.toChar }.mkString) Seq(row) } override def endPartition(): Iterable[Row] = Seq.empty @@ -1740,8 +1666,7 @@ class UDTFSuite extends TestData { StructField("decimal_str", StringType), StructField("str1_str", StringType), StructField("str2_str", StringType), - StructField("bytes_str", StringType) - ) + StructField("bytes_str", StringType)) } val tableFunction = session.udtf.registerTemporary(new UDTFInputBasicTypes) @@ -1757,8 +1682,7 @@ class UDTFSuite extends TestData { lit(decimal).cast(DecimalType(38, 18)), lit("scala"), lit(new java.lang.String("java")), - lit("bytes".getBytes()) - ) + lit("bytes".getBytes())) assert( getSchemaString(df1.schema) == """root @@ -1772,14 +1696,21 @@ class UDTFSuite extends TestData { | |--STR1_STR: String (nullable = true) | |--STR2_STR: String (nullable = true) | |--BYTES_STR: String (nullable = true) - |""".stripMargin - ) + |""".stripMargin) checkAnswer( df1, Seq( - Row("1", "2", "3", "4.4", "8.8", "true", "123.456000000000000000", "scala", "java", "bytes") - ) - ) + Row( + "1", + "2", + "3", + "4.4", + "8.8", + "true", + "123.456000000000000000", + "scala", + "java", + "bytes"))) } test("test input Type: Date/Time/TimeStamp", JavaStoredProcExclude) { @@ -1801,8 +1732,7 @@ class UDTFSuite extends TestData { StructType( StructField("date_str", StringType), StructField("time_str", StringType), - StructField("timestamp_str", StringType) - ) + StructField("timestamp_str", StringType)) } createTable(tableName, "date Date, time Time, ts timestamp_ntz") @@ -1810,8 +1740,7 @@ class UDTFSuite extends TestData { s"insert into $tableName values " + 
s"('2022-01-25', '00:00:00', '2022-01-25 00:00:00.000')," + s"('2022-01-25', '12:13:14', '2022-01-25 12:13:14.123')", - session - ) + session) val tableFunction = session.udtf.registerTemporary(new UDTFInputTimestampTypes) val df1 = session.table(tableName).join(tableFunction, col("date"), col("time"), col("ts")) @@ -1824,8 +1753,7 @@ class UDTFSuite extends TestData { | |--DATE_STR: String (nullable = true) | |--TIME_STR: String (nullable = true) | |--TIMESTAMP_STR: String (nullable = true) - |""".stripMargin - ) + |""".stripMargin) checkAnswer( df1, Seq( @@ -1835,18 +1763,14 @@ class UDTFSuite extends TestData { Timestamp.valueOf("2022-01-25 00:00:00.0"), "2022-01-25", "00:00:00", - "2022-01-25 00:00:00.0" - ), + "2022-01-25 00:00:00.0"), Row( Date.valueOf("2022-01-25"), Time.valueOf("12:13:14"), Timestamp.valueOf("2022-01-25 12:13:14.123"), "2022-01-25", "12:13:14", - "2022-01-25 12:13:14.123" - ) - ) - ) + "2022-01-25 12:13:14.123"))) } finally { TimeZone.setDefault(oldTimeZone) session.sql(s"alter session set TIMEZONE = '$sfTimezone'").collect() @@ -1859,15 +1783,13 @@ class UDTFSuite extends TestData { Array[String], Array[Variant], mutable.Map[String, String], - mutable.Map[String, Variant] - ] { + mutable.Map[String, Variant]] { override def process( v: Variant, a1: Array[String], a2: Array[Variant], m1: mutable.Map[String, String], - m2: mutable.Map[String, Variant] - ): Iterable[Row] = { + m2: mutable.Map[String, Variant]): Iterable[Row] = { val (r1, r2) = (a1.mkString("[", ",", "]"), a2.map(_.asString()).mkString("[", ",", "]")) val r3 = m1 .map { x => @@ -1889,8 +1811,7 @@ class UDTFSuite extends TestData { StructField("a1_str", StringType), StructField("a2_str", StringType), StructField("m1_str", StringType), - StructField("m2_str", StringType) - ) + StructField("m2_str", StringType)) } val tableFunction = session.udtf.registerTemporary(udtf) createTable(tableName, "v variant, a1 array, a2 array, m1 object, m2 object") @@ -1898,8 +1819,7 @@ class 
UDTFSuite extends TestData { s"insert into $tableName " + s"select to_variant('v1'), array_construct('a1', 'a1'), array_construct('a2', 'a2'), " + s"object_construct('m1', 'one'), object_construct('m2', 'two')", - session - ) + session) val df1 = session .table(tableName) @@ -1913,8 +1833,7 @@ class UDTFSuite extends TestData { | |--A2_STR: String (nullable = true) | |--M1_STR: String (nullable = true) | |--M2_STR: String (nullable = true) - |""".stripMargin - ) + |""".stripMargin) checkAnswer(df1, Seq(Row("v1", "[a1,a1]", "[a2,a2]", "{(m1 -> one)}", "{(m2 -> two)}"))) } @@ -1964,8 +1883,7 @@ class UDTFSuite extends TestData { floatValue.toDouble, java.math.BigDecimal.valueOf(floatValue).setScale(3, RoundingMode.HALF_DOWN), data, - data.getBytes() - ) + data.getBytes()) Seq(row, row) } override def endPartition(): Iterable[Row] = Seq.empty @@ -1979,8 +1897,7 @@ class UDTFSuite extends TestData { StructField("double", DoubleType), StructField("decimal", DecimalType(10, 3)), StructField("string", StringType), - StructField("binary", BinaryType) - ) + StructField("binary", BinaryType)) } val tableFunction = session.udtf.registerTemporary(new ReturnBasicTypes()) @@ -2002,8 +1919,7 @@ class UDTFSuite extends TestData { | |--DECIMAL: Decimal(10, 3) (nullable = true) | |--STRING: String (nullable = true) | |--BINARY: Binary (nullable = true) - |""".stripMargin - ) + |""".stripMargin) val b1 = "-128".getBytes() checkAnswer( df1, @@ -2011,9 +1927,7 @@ class UDTFSuite extends TestData { Row("-128", false, -128, -128, -128, -128.128, -128.128, -128.128, "-128", b1), Row("-128", false, -128, -128, -128, -128.128, -128.128, -128.128, "-128", b1), Row("128", true, 128, 128, 128, 128.128, 128.128, 128.128, "128", "128".getBytes()), - Row("128", true, 128, 128, 128, 128.128, 128.128, 128.128, "128", "128".getBytes()) - ) - ) + Row("128", true, 128, 128, 128, 128.128, 128.128, 128.128, "128", "128".getBytes()))) } test("test output type: Time, Date, Timestamp", 
JavaStoredProcExclude) { @@ -2037,8 +1951,7 @@ class UDTFSuite extends TestData { StructType( StructField("time", TimeType), StructField("date", DateType), - StructField("timestamp", TimestampType) - ) + StructField("timestamp", TimestampType)) } val tableFunction = session.udtf.registerTemporary(new ReturnTimestampTypes3) @@ -2055,15 +1968,12 @@ class UDTFSuite extends TestData { | |--TIME: Time (nullable = true) | |--DATE: Date (nullable = true) | |--TIMESTAMP: Timestamp (nullable = true) - |""".stripMargin - ) + |""".stripMargin) checkAnswer( df1, Seq( Row(ts, Time.valueOf(time), Date.valueOf(date), Timestamp.valueOf(ts)), - Row(ts, Time.valueOf(time), Date.valueOf(date), Timestamp.valueOf(ts)) - ) - ) + Row(ts, Time.valueOf(time), Date.valueOf(date), Timestamp.valueOf(ts)))) } finally { TimeZone.setDefault(oldTimeZone) session.sql(s"alter session set TIMEZONE = '$sfTimezone'").collect() @@ -2072,8 +1982,7 @@ class UDTFSuite extends TestData { test( "test output type: VariantType/ArrayType(StringType|VariantType)/" + - "MapType(StringType, StringType|VariantType)" - ) { + "MapType(StringType, StringType|VariantType)") { class ReturnComplexTypes extends UDTF1[String] { override def process(data: String): Iterable[Row] = { val arr = data.split(" ") @@ -2082,8 +1991,7 @@ class UDTFSuite extends TestData { val variantMap = Map(arr(0) -> new Variant(arr(1))) Seq( Row(data, new Variant(data), arr, arr.map(new Variant(_)), stringMap, variantMap), - Row(data, new Variant(data), seq, seq.map(new Variant(_)), stringMap, variantMap) - ) + Row(data, new Variant(data), seq, seq.map(new Variant(_)), stringMap, variantMap)) } override def endPartition(): Iterable[Row] = Seq.empty override def outputSchema(): StructType = @@ -2093,8 +2001,7 @@ class UDTFSuite extends TestData { StructField("string_array", ArrayType(StringType)), StructField("variant_array", ArrayType(VariantType)), StructField("string_map", MapType(StringType, StringType)), - StructField("variant_map", 
MapType(StringType, VariantType)) - ) + StructField("variant_map", MapType(StringType, VariantType))) } val tableFunction = session.udtf.registerTemporary(new ReturnComplexTypes()) @@ -2108,8 +2015,7 @@ class UDTFSuite extends TestData { | |--VARIANT_ARRAY: Array (nullable = true) | |--STRING_MAP: Map (nullable = true) | |--VARIANT_MAP: Map (nullable = true) - |""".stripMargin - ) + |""".stripMargin) checkAnswer( df1, Seq( @@ -2119,18 +2025,14 @@ class UDTFSuite extends TestData { "[\n \"v1\",\n \"v2\"\n]", "[\n \"v1\",\n \"v2\"\n]", "{\n \"v1\": \"v2\"\n}", - "{\n \"v1\": \"v2\"\n}" - ), + "{\n \"v1\": \"v2\"\n}"), Row( "v1 v2", "\"v1 v2\"", "[\n \"v1\",\n \"v2\"\n]", "[\n \"v1\",\n \"v2\"\n]", "{\n \"v1\": \"v2\"\n}", - "{\n \"v1\": \"v2\"\n}" - ) - ) - ) + "{\n \"v1\": \"v2\"\n}"))) } test("use UDF and UDTF in one session", JavaStoredProcExclude) { @@ -2197,30 +2099,25 @@ class UDTFSuite extends TestData { val tf = session.udtf.registerTemporary(TableFunc1) checkAnswer( df.join(tf, Seq(df("b")), Seq(df("a")), Seq(df("b"))), - Seq(Row("a", null, "Map(b -> 2, c -> 1)"), Row("d", null, "Map(e -> 1)")) - ) + Seq(Row("a", null, "Map(b -> 2, c -> 1)"), Row("d", null, "Map(e -> 1)"))) checkAnswer( df.join(tf, Seq(df("b")), Seq(df("a")), Seq.empty), - Seq(Row("a", null, "Map(b -> 2, c -> 1)"), Row("d", null, "Map(e -> 1)")) - ) + Seq(Row("a", null, "Map(b -> 2, c -> 1)"), Row("d", null, "Map(e -> 1)"))) df.join(tf, Seq(df("b")), Seq.empty, Seq(df("b"))).show() df.join(tf, Seq(df("b")), Seq.empty, Seq.empty).show() checkAnswer( df.join(tf, Map("arg1" -> df("b")), Seq(df("a")), Seq(df("b"))), - Seq(Row("a", null, "Map(b -> 2, c -> 1)"), Row("d", null, "Map(e -> 1)")) - ) + Seq(Row("a", null, "Map(b -> 2, c -> 1)"), Row("d", null, "Map(e -> 1)"))) checkAnswer( df.join(tf(Map("arg1" -> df("b"))), Seq(df("a")), Seq(df("b"))), - Seq(Row("a", null, "Map(b -> 2, c -> 1)"), Row("d", null, "Map(e -> 1)")) - ) + Seq(Row("a", null, "Map(b -> 2, c -> 1)"), Row("d", null, "Map(e -> 
1)"))) checkAnswer( df.join(tf, Map("arg1" -> df("b")), Seq(df("a")), Seq.empty), - Seq(Row("a", null, "Map(b -> 2, c -> 1)"), Row("d", null, "Map(e -> 1)")) - ) + Seq(Row("a", null, "Map(b -> 2, c -> 1)"), Row("d", null, "Map(e -> 1)"))) df.join(tf, Map("arg1" -> df("b")), Seq.empty, Seq(df("b"))).show() df.join(tf, Map("arg1" -> df("b")), Seq.empty, Seq.empty).show() } diff --git a/src/test/scala/com/snowflake/snowpark_test/UdxOpenTelemetrySuite.scala b/src/test/scala/com/snowflake/snowpark_test/UdxOpenTelemetrySuite.scala index e6cdde35..8f21b2fd 100644 --- a/src/test/scala/com/snowflake/snowpark_test/UdxOpenTelemetrySuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/UdxOpenTelemetrySuite.scala @@ -85,8 +85,7 @@ class UdxOpenTelemetrySuite extends OpenTelemetryEnabled { className: String, funcName: String, execName: String, - execFilePath: String - ): Unit = { + execFilePath: String): Unit = { val stack = Thread.currentThread().getStackTrace val file = stack(2) // this file checkSpan( @@ -96,16 +95,14 @@ class UdxOpenTelemetrySuite extends OpenTelemetryEnabled { file.getLineNumber - 1, execName, "SnowUDF.compute", - execFilePath - ) + execFilePath) } def checkUdtfSpan( className: String, funcName: String, execName: String, - execFilePath: String - ): Unit = { + execFilePath: String): Unit = { val stack = Thread.currentThread().getStackTrace val file = stack(2) // this file checkSpan( @@ -115,7 +112,6 @@ class UdxOpenTelemetrySuite extends OpenTelemetryEnabled { file.getLineNumber - 1, execName, "SnowparkGeneratedUDTF", - execFilePath - ) + execFilePath) } } diff --git a/src/test/scala/com/snowflake/snowpark_test/UpdatableSuite.scala b/src/test/scala/com/snowflake/snowpark_test/UpdatableSuite.scala index 5e096e05..2360866e 100644 --- a/src/test/scala/com/snowflake/snowpark_test/UpdatableSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/UpdatableSuite.scala @@ -32,14 +32,12 @@ class UpdatableSuite extends TestData { s"insert into 
$semiStructuredTable select parse_json(a), parse_json(b), " + s"parse_json(a), to_geography(c) from values('[1,2]', '{a:1}', 'POINT(-122.35 37.55)')," + s"('[1,2,3]', '{b:2}', 'POINT(-12 37)') as T(a,b,c)", - session - ) + session) createTable(timeTable, "time time") runQuery( s"insert into $timeTable select to_time(a) from values('09:15:29')," + s"('09:15:29.99999999') as T(a)", - session - ) + session) } override def afterAll: Unit = { @@ -100,8 +98,7 @@ class UpdatableSuite extends TestData { checkAnswer( t2, Seq(Row(0, "A"), Row(0, "B"), Row(0, "C"), Row(4, "D"), Row(5, "E"), Row(6, "F")), - sort = false - ) + sort = false) testData2.write.mode(SaveMode.Overwrite).saveAsTable(tableName) upperCaseData.write.mode(SaveMode.Overwrite).saveAsTable(tableName2) @@ -109,8 +106,7 @@ class UpdatableSuite extends TestData { checkAnswer( t2, Seq(Row(0, "A"), Row(0, "B"), Row(0, "C"), Row(4, "D"), Row(5, "E"), Row(6, "F")), - sort = false - ) + sort = false) upperCaseData.write.mode(SaveMode.Overwrite).saveAsTable(tableName2) import session.implicits._ @@ -118,8 +114,7 @@ class UpdatableSuite extends TestData { assert(t2.update(Map("n" -> lit(0)), t2("L") === sd("c"), sd) == UpdateResult(4, 0)) checkAnswer( t2, - Seq(Row(0, "A"), Row(0, "B"), Row(0, "D"), Row(0, "E"), Row(3, "C"), Row(6, "F")) - ) + Seq(Row(0, "A"), Row(0, "B"), Row(0, "D"), Row(0, "E"), Row(3, "C"), Row(6, "F"))) } test("update with join involving ambiguous columns") { @@ -131,8 +126,7 @@ class UpdatableSuite extends TestData { checkAnswer( t1, Seq(Row(0, 1), Row(0, 2), Row(0, 1), Row(0, 2), Row(3, 1), Row(3, 2)), - sort = false - ) + sort = false) upperCaseData.write.mode(SaveMode.Overwrite).saveAsTable(tableName3) val up = session.table(tableName3) @@ -141,8 +135,7 @@ class UpdatableSuite extends TestData { assert(up.update(Map("n" -> lit(0)), up("L") === sd("L"), sd) == UpdateResult(4, 0)) checkAnswer( up, - Seq(Row(0, "A"), Row(0, "B"), Row(0, "D"), Row(0, "E"), Row(3, "C"), Row(6, "F")) - ) + 
Seq(Row(0, "A"), Row(0, "B"), Row(0, "D"), Row(0, "E"), Row(3, "C"), Row(6, "F"))) } test("update with join with aggregated source data") { @@ -155,8 +148,7 @@ class UpdatableSuite extends TestData { val b = src.groupBy(col("k")).agg(min(col("v")).as("v")) assert( target.update(Map(target("v") -> b("v")), target("k") === b("k"), b) - == UpdateResult(1, 0) - ) + == UpdateResult(1, 0)) checkAnswer(target, Seq(Row(0, 11))) } @@ -215,8 +207,7 @@ class UpdatableSuite extends TestData { .merge(source, target("id") === source("id")) .whenMatched .update(Map(target("desc") -> source("desc"))) - .collect() == MergeResult(0, 2, 0) - ) + .collect() == MergeResult(0, 2, 0)) checkAnswer(target, Seq(Row(10, "new"), Row(10, "new"), Row(11, "old"))) targetDF.write.mode(SaveMode.Overwrite).saveAsTable(tableName) @@ -225,8 +216,7 @@ class UpdatableSuite extends TestData { .merge(source, target("id") === source("id")) .whenMatched .update(Map("desc" -> source("desc"))) - .collect() == MergeResult(0, 2, 0) - ) + .collect() == MergeResult(0, 2, 0)) checkAnswer(target, Seq(Row(10, "new"), Row(10, "new"), Row(11, "old"))) targetDF.write.mode(SaveMode.Overwrite).saveAsTable(tableName) @@ -235,8 +225,7 @@ class UpdatableSuite extends TestData { .merge(source, target("id") === source("id")) .whenMatched(target("desc") === lit("old")) .update(Map(target("desc") -> source("desc"))) - .collect() == MergeResult(0, 1, 0) - ) + .collect() == MergeResult(0, 1, 0)) checkAnswer(target, Seq(Row(10, "new"), Row(10, "too_old"), Row(11, "old"))) } @@ -252,8 +241,7 @@ class UpdatableSuite extends TestData { .merge(source, target("id") === source("id")) .whenMatched .delete() - .collect() == MergeResult(0, 0, 2) - ) + .collect() == MergeResult(0, 0, 2)) checkAnswer(target, Seq(Row(11, "old"))) targetDF.write.mode(SaveMode.Overwrite).saveAsTable(tableName) @@ -262,8 +250,7 @@ class UpdatableSuite extends TestData { .merge(source, target("id") === source("id")) .whenMatched(target("desc") === lit("old")) 
.delete() - .collect() == MergeResult(0, 0, 1) - ) + .collect() == MergeResult(0, 0, 1)) checkAnswer(target, Seq(Row(10, "too_old"), Row(11, "old"))) } @@ -279,8 +266,7 @@ class UpdatableSuite extends TestData { .merge(source, target("id") === source("id")) .whenNotMatched .insert(Map(target("id") -> source("id"), target("desc") -> source("desc"))) - .collect() == MergeResult(2, 0, 0) - ) + .collect() == MergeResult(2, 0, 0)) checkAnswer(target, Seq(Row(10, "old"), Row(11, "new"), Row(12, "new"), Row(12, "old"))) targetDF.write.mode(SaveMode.Overwrite).saveAsTable(tableName) @@ -289,8 +275,7 @@ class UpdatableSuite extends TestData { .merge(source, target("id") === source("id")) .whenNotMatched .insert(Seq(source("id"), source("desc"))) - .collect() == MergeResult(2, 0, 0) - ) + .collect() == MergeResult(2, 0, 0)) checkAnswer(target, Seq(Row(10, "old"), Row(11, "new"), Row(12, "new"), Row(12, "old"))) targetDF.write.mode(SaveMode.Overwrite).saveAsTable(tableName) @@ -299,8 +284,7 @@ class UpdatableSuite extends TestData { .merge(source, target("id") === source("id")) .whenNotMatched .insert(Map("id" -> source("id"), "desc" -> source("desc"))) - .collect() == MergeResult(2, 0, 0) - ) + .collect() == MergeResult(2, 0, 0)) checkAnswer(target, Seq(Row(10, "old"), Row(11, "new"), Row(12, "new"), Row(12, "old"))) targetDF.write.mode(SaveMode.Overwrite).saveAsTable(tableName) @@ -309,8 +293,7 @@ class UpdatableSuite extends TestData { .merge(source, target("id") === source("id")) .whenNotMatched(source("desc") === lit("new")) .insert(Map(target("id") -> source("id"), target("desc") -> source("desc"))) - .collect() == MergeResult(1, 0, 0) - ) + .collect() == MergeResult(1, 0, 0)) checkAnswer(target, Seq(Row(10, "old"), Row(11, "new"), Row(12, "new"))) } @@ -332,8 +315,7 @@ class UpdatableSuite extends TestData { .insert(Map(target("id") -> source("id"), target("desc") -> lit("new"))) .whenNotMatched .insert(Map(target("id") -> source("id"), target("desc") -> 
source("desc"))) - .collect() == MergeResult(2, 1, 1) - ) + .collect() == MergeResult(2, 1, 1)) checkAnswer(target, Seq(Row(10, "new"), Row(11, "old"), Row(12, "new"), Row(13, "new"))) } @@ -352,8 +334,7 @@ class UpdatableSuite extends TestData { .update(Map(target("v") -> source("v"))) .whenNotMatched .insert(Map(target("k") -> source("k"), target("v") -> source("v"))) - .collect() == MergeResult(0, 1, 0) - ) + .collect() == MergeResult(0, 1, 0)) checkAnswer(target, Seq(Row(0, 12))) } @@ -384,12 +365,10 @@ class UpdatableSuite extends TestData { .insert(Map(target("v") -> source("v"))) .whenNotMatched .insert(Map("k" -> source("k"), "v" -> source("v"))) - .collect() == MergeResult(4, 2, 2) - ) + .collect() == MergeResult(4, 2, 2)) checkAnswer( target, - Seq(Row(1, 21), Row(3, 3), Row(4, 4), Row(5, 25), Row(7, null), Row(null, 26)) - ) + Seq(Row(1, 21), Row(3, 3), Row(4, 4), Row(5, 25), Row(7, null), Row(null, 26))) } test("clone") { diff --git a/src/test/scala/com/snowflake/snowpark_test/ViewSuite.scala b/src/test/scala/com/snowflake/snowpark_test/ViewSuite.scala index 56554a06..5c35d624 100644 --- a/src/test/scala/com/snowflake/snowpark_test/ViewSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/ViewSuite.scala @@ -24,8 +24,7 @@ class ViewSuite extends TestData { checkAnswer( session.sql(s"select * from $viewName1"), Seq(Row(1.111), Row(2.222), Row(3.333)), - sort = false - ) + sort = false) } test("view name with special character") { @@ -33,14 +32,12 @@ class ViewSuite extends TestData { checkAnswer( session.sql(s"select * from ${quoteName(viewName2)}"), Seq(Row(1, 2), Row(3, 4)), - sort = false - ) + sort = false) } test("only works on select") { assertThrows[IllegalArgumentException]( - session.sql("show tables").createOrReplaceView(viewName1) - ) + session.sql("show tables").createOrReplaceView(viewName1)) } // getDatabaseFromProperties will read local files, which is not supported in Java SP yet. 
diff --git a/src/test/scala/com/snowflake/snowpark_test/WindowFramesSuite.scala b/src/test/scala/com/snowflake/snowpark_test/WindowFramesSuite.scala index 13a824c6..89e8cbd4 100644 --- a/src/test/scala/com/snowflake/snowpark_test/WindowFramesSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/WindowFramesSuite.scala @@ -20,8 +20,7 @@ class WindowFramesSuite extends TestData { checkAnswer( df.select($"key", lead($"value", 1).over(window), lag($"value", 1).over(window)), - Row(1, "3", null) :: Row(1, null, "1") :: Row(2, "4", null) :: Row(2, null, "2") :: Nil - ) + Row(1, "3", null) :: Row(1, null, "1") :: Row(2, "4", null) :: Row(2, null, "2") :: Nil) } test("reverse lead/lag with positive offset") { @@ -30,8 +29,7 @@ class WindowFramesSuite extends TestData { checkAnswer( df.select($"key", lead($"value", 1).over(window), lag($"value", 1).over(window)), - Row(1, "1", null) :: Row(1, null, "3") :: Row(2, "2", null) :: Row(2, null, "4") :: Nil - ) + Row(1, "1", null) :: Row(1, null, "3") :: Row(2, "2", null) :: Row(2, null, "4") :: Nil) } test("lead/lag with negative offset") { @@ -40,8 +38,7 @@ class WindowFramesSuite extends TestData { checkAnswer( df.select($"key", lead($"value", -1).over(window), lag($"value", -1).over(window)), - Row(1, "1", null) :: Row(1, null, "3") :: Row(2, "2", null) :: Row(2, null, "4") :: Nil - ) + Row(1, "1", null) :: Row(1, null, "3") :: Row(2, "2", null) :: Row(2, null, "4") :: Nil) } test("reverse lead/lag with negative offset") { @@ -50,8 +47,7 @@ class WindowFramesSuite extends TestData { checkAnswer( df.select($"key", lead($"value", -1).over(window), lag($"value", -1).over(window)), - Row(1, "3", null) :: Row(1, null, "1") :: Row(2, "4", null) :: Row(2, null, "2") :: Nil - ) + Row(1, "3", null) :: Row(1, null, "1") :: Row(2, "4", null) :: Row(2, null, "2") :: Nil) } test("lead/lag with default value") { @@ -65,12 +61,10 @@ class WindowFramesSuite extends TestData { lead($"value", 2).over(window), lag($"value", 
2).over(window), lead($"value", -2).over(window), - lag($"value", -2).over(window) - ), + lag($"value", -2).over(window)), Row(1, default, default, default, default) :: Row(1, default, default, default, default) :: Row(2, "5", default, default, "5") :: Row(2, default, "2", "2", default) :: - Row(2, default, default, default, default) :: Nil - ) + Row(2, default, default, default, default) :: Nil) } test("unbounded rows/range between with aggregation") { @@ -83,10 +77,8 @@ class WindowFramesSuite extends TestData { sum($"value") .over(window.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), sum($"value") - .over(window.rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing)) - ), - Row("one", 4, 4) :: Row("one", 4, 4) :: Row("two", 6, 6) :: Row("two", 6, 6) :: Nil - ) + .over(window.rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing))), + Row("one", 4, 4) :: Row("one", 4, 4) :: Row("two", 6, 6) :: Row("two", 6, 6) :: Nil) } test("SN - rows between should accept int/long values as boundary") { @@ -97,21 +89,17 @@ class WindowFramesSuite extends TestData { checkAnswer( df.select( $"key", - count($"key").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(0, 100)) - ), - Seq(Row(1, 3), Row(1, 4), Row(2, 2), Row(2147483650L, 1), Row(2147483650L, 1), Row(3, 2)) - ) + count($"key").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(0, 100))), + Seq(Row(1, 3), Row(1, 4), Row(2, 2), Row(2147483650L, 1), Row(2147483650L, 1), Row(3, 2))) val e = intercept[SnowparkClientException]( df.select( $"key", count($"key") - .over(Window.partitionBy($"value").orderBy($"key").rowsBetween(0, 2147483648L)) - ) - ) + .over(Window.partitionBy($"value").orderBy($"key").rowsBetween(0, 2147483648L)))) assert( - e.message.contains("The ending point for the window frame is not a valid integer: 2147483648") - ) + e.message.contains( + "The ending point for the window frame is not a valid integer: 2147483648")) } @@ -123,22 +111,17 
@@ class WindowFramesSuite extends TestData { df.select( $"key", min($"key") - .over(window.rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing)) - ), - Seq(Row(1, 1)) - ) + .over(window.rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing))), + Seq(Row(1, 1))) intercept[SnowflakeSQLException]( - df.select(min($"key").over(window.rangeBetween(Window.unboundedPreceding, 1))).collect - ) + df.select(min($"key").over(window.rangeBetween(Window.unboundedPreceding, 1))).collect) intercept[SnowflakeSQLException]( - df.select(min($"key").over(window.rangeBetween(-1, Window.unboundedFollowing))).collect - ) + df.select(min($"key").over(window.rangeBetween(-1, Window.unboundedFollowing))).collect) intercept[SnowflakeSQLException]( - df.select(min($"key").over(window.rangeBetween(-1, 1))).collect - ) + df.select(min($"key").over(window.rangeBetween(-1, 1))).collect) } test("SN - range between should accept numeric values only when bounded") { @@ -149,23 +132,18 @@ class WindowFramesSuite extends TestData { df.select( $"value", min($"value") - .over(window.rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing)) - ), - Row("non_numeric", "non_numeric") :: Nil - ) + .over(window.rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing))), + Row("non_numeric", "non_numeric") :: Nil) // TODO: Add another test with eager mode enabled intercept[SnowflakeSQLException]( - df.select(min($"value").over(window.rangeBetween(Window.unboundedPreceding, 1))).collect() - ) + df.select(min($"value").over(window.rangeBetween(Window.unboundedPreceding, 1))).collect()) intercept[SnowflakeSQLException]( - df.select(min($"value").over(window.rangeBetween(-1, Window.unboundedFollowing))).collect() - ) + df.select(min($"value").over(window.rangeBetween(-1, Window.unboundedFollowing))).collect()) intercept[SnowflakeSQLException]( - df.select(min($"value").over(window.rangeBetween(-1, 1))).collect() - ) + 
df.select(min($"value").over(window.rangeBetween(-1, 1))).collect()) } test("SN - sliding rows between with aggregation") { @@ -175,8 +153,7 @@ class WindowFramesSuite extends TestData { checkAnswer( df.select($"key", avg($"key").over(window)), Row(1, 1.333) :: Row(1, 1.333) :: Row(2, 1.500) :: Row(2, 2.000) :: - Row(2, 2.000) :: Nil - ) + Row(2, 2.000) :: Nil) } test("SN - reverse sliding rows between with aggregation") { @@ -186,8 +163,7 @@ class WindowFramesSuite extends TestData { checkAnswer( df.select($"key", avg($"key").over(window)), Row(1, 1.000) :: Row(1, 1.333) :: Row(2, 1.333) :: Row(2, 2.000) :: - Row(2, 2.000) :: Nil - ) + Row(2, 2.000) :: Nil) } test("Window function any_value()") { @@ -200,8 +176,7 @@ class WindowFramesSuite extends TestData { .select( $"key", any_value($"value1").over(Window.partitionBy($"key")), - any_value($"value2").over(Window.partitionBy($"key")) - ) + any_value($"value2").over(Window.partitionBy($"key"))) .collect() assert(rows.length == 4) rows.foreach { row => diff --git a/src/test/scala/com/snowflake/snowpark_test/WindowSpecSuite.scala b/src/test/scala/com/snowflake/snowpark_test/WindowSpecSuite.scala index 02906237..00100566 100644 --- a/src/test/scala/com/snowflake/snowpark_test/WindowSpecSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/WindowSpecSuite.scala @@ -17,15 +17,13 @@ class WindowSpecSuite extends TestData { checkAnswer( df.select($"key", avg($"key").over(window)), Row(1, 1.333) :: Row(1, 1.333) :: Row(2, 1.500) :: Row(2, 2.000) :: - Row(2, 2.000) :: Nil - ) + Row(2, 2.000) :: Nil) val window2 = Window.rowsBetween(Window.currentRow, 2).orderBy($"key") checkAnswer( df.select($"key", avg($"key").over(window2)), Seq(Row(2, 2.000), Row(2, 2.000), Row(2, 2.000), Row(1, 1.666), Row(1, 1.333)), - sort = false - ) + sort = false) } test("rangeBetween") { @@ -36,17 +34,14 @@ class WindowSpecSuite extends TestData { df.select( $"value", min($"value") - .over(window.rangeBetween(Window.unboundedPreceding, 
Window.unboundedFollowing)) - ), - Row("non_numeric", "non_numeric") :: Nil - ) + .over(window.rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing))), + Row("non_numeric", "non_numeric") :: Nil) val window2 = Window.rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing).orderBy($"value") checkAnswer( df.select($"value", min($"value").over(window2)), - Row("non_numeric", "non_numeric") :: Nil - ) + Row("non_numeric", "non_numeric") :: Nil) } test("window function with aggregates") { @@ -56,8 +51,7 @@ class WindowSpecSuite extends TestData { checkAnswer( df.groupBy($"key") .agg(sum($"value"), sum(sum($"value")).over(window) - sum($"value")), - Seq(Row("a", 6, 9), Row("b", 9, 6)) - ) + Seq(Row("a", 6, 9), Row("b", 9, 6))) } test("Window functions inside WHERE and HAVING clauses") { @@ -68,47 +62,36 @@ class WindowSpecSuite extends TestData { } checkAnalysisError[SnowflakeSQLException]( - testData2.select("a").where(rank().over(Window.orderBy($"b")) === 1) - ) + testData2.select("a").where(rank().over(Window.orderBy($"b")) === 1)) checkAnalysisError[SnowflakeSQLException]( - testData2.where($"b" === 2 && rank().over(Window.orderBy($"b")) === 1) - ) + testData2.where($"b" === 2 && rank().over(Window.orderBy($"b")) === 1)) checkAnalysisError[SnowflakeSQLException]( testData2 .groupBy($"a") .agg(avg($"b").as("avgb")) - .where($"a" > $"avgb" && rank().over(Window.orderBy($"a")) === 1) - ) + .where($"a" > $"avgb" && rank().over(Window.orderBy($"a")) === 1)) checkAnalysisError[SnowflakeSQLException]( testData2 .groupBy($"a") .agg(max($"b").as("maxb"), sum($"b").as("sumb")) - .where(rank().over(Window.orderBy($"a")) === 1) - ) + .where(rank().over(Window.orderBy($"a")) === 1)) checkAnalysisError[SnowflakeSQLException]( testData2 .groupBy($"a") .agg(max($"b").as("maxb"), sum($"b").as("sumb")) - .where($"sumb" === 5 && rank().over(Window.orderBy($"a")) === 1) - ) + .where($"sumb" === 5 && rank().over(Window.orderBy($"a")) === 1)) 
testData2.createOrReplaceTempView("testData2") checkAnalysisError[SnowflakeSQLException]( - session.sql("SELECT a FROM testData2 WHERE RANK() OVER(ORDER BY b) = 1") - ) + session.sql("SELECT a FROM testData2 WHERE RANK() OVER(ORDER BY b) = 1")) checkAnalysisError[SnowflakeSQLException]( - session.sql("SELECT * FROM testData2 WHERE b = 2 AND RANK() OVER(ORDER BY b) = 1") - ) + session.sql("SELECT * FROM testData2 WHERE b = 2 AND RANK() OVER(ORDER BY b) = 1")) checkAnalysisError[SnowflakeSQLException]( session.sql( - "SELECT * FROM testData2 GROUP BY a HAVING a > AVG(b) AND RANK() OVER(ORDER BY a) = 1" - ) - ) + "SELECT * FROM testData2 GROUP BY a HAVING a > AVG(b) AND RANK() OVER(ORDER BY a) = 1")) checkAnalysisError[SnowflakeSQLException]( session.sql( - "SELECT a, MAX(b), SUM(b) FROM testData2 GROUP BY a HAVING RANK() OVER(ORDER BY a) = 1" - ) - ) + "SELECT a, MAX(b), SUM(b) FROM testData2 GROUP BY a HAVING RANK() OVER(ORDER BY a) = 1")) checkAnalysisError[SnowflakeSQLException](session.sql(s"""SELECT a, MAX(b) |FROM testData2 |GROUP BY a @@ -122,8 +105,7 @@ class WindowSpecSuite extends TestData { checkAnswer( df.select(lead($"key", 1).over(w), lead($"value", 1).over(w)), - Row(1, "1") :: Row(2, "2") :: Row(null, null) :: Row(null, null) :: Nil - ) + Row(1, "1") :: Row(2, "2") :: Row(null, null) :: Row(null, null) :: Nil) } test("reuse window orderBy") { @@ -132,8 +114,7 @@ class WindowSpecSuite extends TestData { checkAnswer( df.select(lead($"key", 1).over(w), lead($"value", 1).over(w)), - Row(1, "1") :: Row(2, "2") :: Row(null, null) :: Row(null, null) :: Nil - ) + Row(1, "1") :: Row(2, "2") :: Row(null, null) :: Row(null, null) :: Nil) } test("rank functions in unspecific window") { @@ -152,13 +133,11 @@ class WindowSpecSuite extends TestData { dense_rank().over(Window.partitionBy($"value").orderBy($"key")), rank().over(Window.partitionBy($"value").orderBy($"key")), cume_dist().over(Window.partitionBy($"value").orderBy($"key")), - 
percent_rank().over(Window.partitionBy($"value").orderBy($"key")) - ), + percent_rank().over(Window.partitionBy($"value").orderBy($"key"))), Row(1, 1, 1, 1.0d, 1, 1, 1, 1, 1, 1, 1.0d / 3.0d, 0.0d) :: Row(1, 1, 1, 1.0d, 1, 1, 1, 1, 1, 1, 1.0d, 0.0d) :: Row(2, 2, 1, 5.0d / 3.0d, 3, 5, 1, 2, 2, 2, 1.0d, 0.5d) :: - Row(2, 2, 1, 5.0d / 3.0d, 3, 5, 2, 3, 2, 2, 1.0d, 0.5d) :: Nil - ) + Row(2, 2, 1, 5.0d / 3.0d, 3, 5, 2, 3, 2, 2, 1.0d, 0.5d) :: Nil) } test("empty over spec") { @@ -167,12 +146,10 @@ class WindowSpecSuite extends TestData { df.createOrReplaceTempView("window_table") checkAnswer( df.select($"key", $"value", sum($"value").over(), avg($"value").over()), - Seq(Row("a", 1, 6, 1.5), Row("a", 1, 6, 1.5), Row("a", 2, 6, 1.5), Row("b", 2, 6, 1.5)) - ) + Seq(Row("a", 1, 6, 1.5), Row("a", 1, 6, 1.5), Row("a", 2, 6, 1.5), Row("b", 2, 6, 1.5))) checkAnswer( session.sql("select key, value, sum(value) over(), avg(value) over() from window_table"), - Seq(Row("a", 1, 6, 1.5), Row("a", 1, 6, 1.5), Row("a", 2, 6, 1.5), Row("b", 2, 6, 1.5)) - ) + Seq(Row("a", 1, 6, 1.5), Row("a", 1, 6, 1.5), Row("a", 2, 6, 1.5), Row("b", 2, 6, 1.5))) } test("null inputs") { @@ -188,18 +165,15 @@ class WindowSpecSuite extends TestData { Row("a", 2, null, null), Row("b", 4, null, null), Row("b", 3, null, null), - Row("b", 2, null, null) - ), - false - ) + Row("b", 2, null, null)), + false) } test("SN - window function should fail if order by clause is not specified") { val df = Seq((1, "1"), (2, "2"), (1, "2"), (2, "2")).toDF("key", "value") val e = intercept[SnowflakeSQLException]( // Here we missed .orderBy("key")! 
- df.select(row_number().over(Window.partitionBy($"value"))).collect() - ) + df.select(row_number().over(Window.partitionBy($"value"))).collect()) } test("SN - corr, covar_pop, stddev_pop functions in specific window") { @@ -212,8 +186,7 @@ class WindowSpecSuite extends TestData { ("f", "p3", 6.0, 12.0), ("g", "p3", 6.0, 12.0), ("h", "p3", 8.0, 16.0), - ("i", "p4", 5.0, 5.0) - ).toDF("key", "partitionId", "value1", "value2") + ("i", "p4", 5.0, 5.0)).toDF("key", "partitionId", "value1", "value2") checkAnswer( df.select( $"key", @@ -221,44 +194,37 @@ class WindowSpecSuite extends TestData { Window .partitionBy($"partitionId") .orderBy($"key") - .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) - ), + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), covar_pop($"value1", $"value2") .over( Window .partitionBy($"partitionId") .orderBy($"key") - .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) - ), + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), var_pop($"value1") .over( Window .partitionBy($"partitionId") .orderBy($"key") - .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) - ), + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), stddev_pop($"value1") .over( Window .partitionBy($"partitionId") .orderBy($"key") - .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) - ), + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), var_pop($"value2") .over( Window .partitionBy($"partitionId") .orderBy($"key") - .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) - ), + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), stddev_pop($"value2") .over( Window .partitionBy($"partitionId") .orderBy($"key") - .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) - ) - ), + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing))), Seq( Row("a", -1.0, -25.0, 25.0, 5.0, 25.0, 5.0), Row("b", 
-1.0, -25.0, 25.0, 5.0, 25.0, 5.0), @@ -268,9 +234,7 @@ class WindowSpecSuite extends TestData { Row("f", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0), Row("g", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0), Row("h", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0), - Row("i", null, 0.0, 0.0, 0.0, 0.0, 0.0) - ) - ) + Row("i", null, 0.0, 0.0, 0.0, 0.0, 0.0))) } test("SN - covar_samp, var_samp (variance), stddev_samp (stddev) functions in specific window") { @@ -283,8 +247,7 @@ class WindowSpecSuite extends TestData { ("f", "p3", 6.0, 12.0), ("g", "p3", 6.0, 12.0), ("h", "p3", 8.0, 16.0), - ("i", "p4", 5.0, 5.0) - ).toDF("key", "partitionId", "value1", "value2") + ("i", "p4", 5.0, 5.0)).toDF("key", "partitionId", "value1", "value2") checkAnswer( df.select( $"key", @@ -292,33 +255,27 @@ class WindowSpecSuite extends TestData { Window .partitionBy($"partitionId") .orderBy($"key") - .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) - ), + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), var_samp($"value1").over( Window .partitionBy($"partitionId") .orderBy($"key") - .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) - ), + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), variance($"value1").over( Window .partitionBy($"partitionId") .orderBy($"key") - .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) - ), + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), stddev_samp($"value1").over( Window .partitionBy($"partitionId") .orderBy($"key") - .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) - ), + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), stddev($"value1").over( Window .partitionBy($"partitionId") .orderBy($"key") - .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) - ) - ), + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing))), Seq( Row("a", -50.0, 50.0, 50.0, 7.0710678118654755, 7.0710678118654755), Row("b", -50.0, 50.0, 50.0, 
7.0710678118654755, 7.0710678118654755), @@ -328,9 +285,7 @@ class WindowSpecSuite extends TestData { Row("f", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544), Row("g", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544), Row("h", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544), - Row("i", null, null, null, null, null) - ) - ) + Row("i", null, null, null, null, null))) } test("SN - skewness and kurtosis functions in window") { @@ -344,8 +299,7 @@ class WindowSpecSuite extends TestData { ("g", "p1", 3.0), ("h", "p2", 1.0), ("i", "p2", 2.0), - ("j", "p2", 5.0) - ).toDF("key", "partition", "value") + ("j", "p2", 5.0)).toDF("key", "partition", "value") checkAnswer( df.select( $"key", @@ -353,15 +307,12 @@ class WindowSpecSuite extends TestData { Window .partitionBy($"partition") .orderBy($"key") - .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) - ), + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), kurtosis($"value").over( Window .partitionBy($"partition") .orderBy($"key") - .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) - ) - ), + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing))), // results are checked by scipy.stats.skew() and scipy.stats.kurtosis() Seq( Row("a", -0.353044463087872, -1.8166064369747783), @@ -373,9 +324,7 @@ class WindowSpecSuite extends TestData { Row("g", -0.353044463087872, -1.8166064369747783), Row("h", 1.293342780733395, null), Row("i", 1.293342780733395, null), - Row("j", 1.293342780733395, null) - ) - ) + Row("j", 1.293342780733395, null))) } test("SN - aggregation function on invalid column") { @@ -393,11 +342,9 @@ class WindowSpecSuite extends TestData { $"key", var_pop($"value").over(window), var_samp($"value").over(window), - approx_count_distinct($"value").over(window) - ), + approx_count_distinct($"value").over(window)), Seq.fill(4)(Row("a", BigDecimal(0.250000), BigDecimal(0.333333), 2)) - ++ Seq.fill(3)(Row("b", 
BigDecimal(0.666667), BigDecimal(1.000000), 3)) - ) + ++ Seq.fill(3)(Row("b", BigDecimal(0.666667), BigDecimal(1.000000), 3))) } test("SN - window functions in multiple selects") { @@ -417,9 +364,7 @@ class WindowSpecSuite extends TestData { Row("S1", "P1", 100, 800, 800), Row("S1", "P1", 700, 800, 800), Row("S2", "P1", 200, 200, 500), - Row("S2", "P2", 300, 300, 500) - ) - ) + Row("S2", "P2", 300, 300, 500))) }