Skip to content

Commit

Permalink
update test format
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-bli committed Aug 21, 2024
1 parent 5d9d18d commit 99792d4
Show file tree
Hide file tree
Showing 72 changed files with 2,581 additions and 5,330 deletions.
24 changes: 24 additions & 0 deletions .github/workflows/precommit-code-verification.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
name: precommit test - code verification
on:
push:
branches: [ main ]
pull_request:
branches: '**'

jobs:
build:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Setup JDK
uses: actions/setup-java@v3
with:
distribution: temurin
java-version: 8
- name: Decrypt profile.properties
run: .github/scripts/decrypt_profile.sh
env:
PROFILE_PASSWORD: ${{ secrets.PROFILE_PASSWORD }}
- name: Run test
run: sbt CodeVerificationTests:test
28 changes: 25 additions & 3 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ val openTelemetryVersion = "1.41.0"
val slf4jVersion = "2.0.4"

lazy val root = (project in file("."))
.configs(CodeVerificationTests)
.settings(
name := "snowpark",
version := "1.15.0-SNAPSHOT",
Expand All @@ -31,7 +32,7 @@ lazy val root = (project in file("."))
"commons-codec" % "commons-codec" % "1.17.0",
"io.opentelemetry" % "opentelemetry-api" % openTelemetryVersion,
"net.snowflake" % "snowflake-jdbc" % "3.17.0",
"com.github.vertical-blank" % "sql-formatter" % "2.0.5",
"com.github.vertical-blank" % "sql-formatter" % "1.0.2",
"com.fasterxml.jackson.core" % "jackson-databind" % jacksonVersion,
"com.fasterxml.jackson.core" % "jackson-core" % jacksonVersion,
"com.fasterxml.jackson.core" % "jackson-annotations" % jacksonVersion,
Expand All @@ -48,8 +49,12 @@ lazy val root = (project in file("."))
javafmtOnCompile := true,
Test / testOptions := Seq(Tests.Argument(TestFrameworks.JUnit, "-a")),
// Test / crossPaths := false,
Test / fork := true,
Test / javaOptions ++= Seq("-Xms1024M", "-Xmx4096M"),
Test / fork := false,
// Test / javaOptions ++= Seq("-Xms1024M", "-Xmx4096M"),
inConfig(CodeVerificationTests)(Defaults.testTasks),
CodeVerificationTests / testOptions += Tests.Filter(isCodeVerification),
// default test
Test / testOptions += Tests.Filter(isRemainingTest),
// Release settings
// usePgpKeyHex(Properties.envOrElse("GPG_SIGNATURE", "12345")),
Global / pgpPassphrase := Properties.envOrNone("GPG_KEY_PASSPHRASE").map(_.toCharArray),
Expand Down Expand Up @@ -80,3 +85,20 @@ lazy val root = (project in file("."))
}
)
)

// Test Groups
// Code Verification
def isCodeVerification(name: String): Boolean = {
name.startsWith("com.snowflake.code_verification")
}
lazy val CodeVerificationTests = config("CodeVerificationTests") extend Test


// Java API Tests
// Java UDx Tests
// Scala UDx Tests
// FIPS Tests

// other Tests
def isRemainingTest(name: String): Boolean = name.endsWith("JavaAPISuite")
// ! isCodeVerification(name)
22 changes: 22 additions & 0 deletions scripts/format_checker.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#!/bin/bash -ex

# format src
sbt clean compile

if [ -z "$(git status --porcelain)" ]; then
Expand All @@ -9,3 +10,24 @@ else
echo "Run 'sbt clean compile' to reformat"
exit 1
fi

# format scala test
sbt test:scalafmt
if [ -z "$(git status --porcelain)" ]; then
echo "Scala Test Code Format Check: Passed!"
else
echo "Scala Test Code Format Check: Failed!"
echo "Run 'sbt test:scalafmt' to reformat"
exit 1
fi

# format java test
sbt test:javafmt
if [ -z "$(git status --porcelain)" ]; then
echo "Scala Test Code Format Check: Passed!"
else
echo "Scala Test Code Format Check: Failed!"
echo "Run 'sbt test:scalafmt' to reformat"
exit 1
fi

23 changes: 11 additions & 12 deletions src/main/scala/com/snowflake/snowpark/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2858,8 +2858,8 @@ object functions {
* specified string column. If the regex did not match, or the specified group did not match, an
* empty string is returned. <pr>Example: from snowflake.snowpark.functions import regexp_extract
* df = session.createDataFrame([["id_20_30", 10], ["id_40_50", 30]], ["id", "age"])
* df.select(regexp_extract("id", r"(\d+)", 1).alias("RES")).show() </pr> <pr> ---------
* \|"RES" | ---------
* df.select(regexp_extract("id", r"(\d+)", 1).alias("RES")).show() </pr> <pr> --------- \|"RES"
* \| ---------
* | 20 |
* |:---|
* | 40 |
Expand Down Expand Up @@ -2896,9 +2896,9 @@ object functions {
* Args: col: The column to evaluate its sign <pr> Example:: >>> df =
* session.create_dataframe([(-2, 2, 0)], ["a", "b", "c"]) >>>
* df.select(sign("a").alias("a_sign"), sign("b").alias("b_sign"),
* sign("c").alias("c_sign")).show() ----------------------------------
* \|"A_SIGN" |"B_SIGN" |"C_SIGN" | ----------------------------------
* \|-1 |1 |0 | ---------------------------------- </pr>
* sign("c").alias("c_sign")).show() ---------------------------------- \|"A_SIGN" |"B_SIGN"
* \|"C_SIGN" | ---------------------------------- \|-1 |1 |0 |
* ---------------------------------- </pr>
* @since 1.14.0
* @param e
* Column to calculate the sign.
Expand All @@ -2918,9 +2918,9 @@ object functions {
* Args: col: The column to evaluate its sign <pr> Example:: >>> df =
* session.create_dataframe([(-2, 2, 0)], ["a", "b", "c"]) >>>
* df.select(sign("a").alias("a_sign"), sign("b").alias("b_sign"),
* sign("c").alias("c_sign")).show() ----------------------------------
* \|"A_SIGN" |"B_SIGN" |"C_SIGN" | ----------------------------------
* \|-1 |1 |0 | ---------------------------------- </pr>
* sign("c").alias("c_sign")).show() ---------------------------------- \|"A_SIGN" |"B_SIGN"
* \|"C_SIGN" | ---------------------------------- \|-1 |1 |0 |
* ---------------------------------- </pr>
* @since 1.14.0
* @param e
* Column to calculate the sign.
Expand Down Expand Up @@ -2973,8 +2973,8 @@ object functions {

/** Returns the input values, pivoted into an ARRAY. If the input is empty, an empty ARRAY is
* returned. <pr> Example:: >>> df = session.create_dataframe([[1], [2], [3], [1]], schema=["a"])
* >>> df.select(array_agg("a", True).alias("result")).show() ------------
* \|"RESULT" | ------------
* >>> df.select(array_agg("a", True).alias("result")).show() ------------ \|"RESULT" |
* ------------
* | [ |
* |:---|
* | 1, |
Expand All @@ -2994,8 +2994,7 @@ object functions {
* returned.
*
* Example:: >>> df = session.create_dataframe([[1], [2], [3], [1]], schema=["a"]) >>>
* df.select(array_agg("a", True).alias("result")).show() ------------
* \|"RESULT" | ------------
* df.select(array_agg("a", True).alias("result")).show() ------------ \|"RESULT" | ------------
* | [ |
* |:---|
* | 1, |
Expand Down
41 changes: 0 additions & 41 deletions src/test/java/com/snowflake/snowpark/TestRunner.java

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ object ClassUtils extends Logging {
class2: Class[B],
class1Only: Set[String] = Set.empty,
class2Only: Set[String] = Set.empty,
class1To2NameMap: Map[String, String] = Map.empty
): Boolean = {
class1To2NameMap: Map[String, String] = Map.empty): Boolean = {
val nameList1 = getAllPublicFunctionNames(class1)
val nameList2 = mutable.Set[String](getAllPublicFunctionNames(class2): _*)
var missed = false
Expand All @@ -63,8 +62,7 @@ object ClassUtils extends Logging {
list2Cache.remove(name)
} else {
logError(s"${class1.getName} misses function $name")
}
)
})
!missed && list2Cache.isEmpty
}
}
Loading

0 comments on commit 99792d4

Please sign in to comment.