diff --git a/fips-pom.xml b/fips-pom.xml index 4f856617..1db71769 100644 --- a/fips-pom.xml +++ b/fips-pom.xml @@ -44,6 +44,7 @@ 4.3.0 2.13.2 2.13.4.2 + 2.13.5 @@ -144,6 +145,11 @@ jackson-annotations ${jackson.version} + + com.fasterxml.jackson.module + jackson-module-scala_2.12 + ${jackson.module.scala.version} + diff --git a/pom.xml b/pom.xml index c9fdf2b4..284b334c 100644 --- a/pom.xml +++ b/pom.xml @@ -46,6 +46,8 @@ 4.3.0 2.13.2 2.13.4.2 + 2.13.5 + @@ -131,6 +133,11 @@ jackson-annotations ${jackson.version} + + com.fasterxml.jackson.module + jackson-module-scala_2.12 + ${jackson.module.scala.version} + diff --git a/src/main/java/com/snowflake/snowpark_java/Session.java b/src/main/java/com/snowflake/snowpark_java/Session.java index b22f327a..c2f4ef6d 100644 --- a/src/main/java/com/snowflake/snowpark_java/Session.java +++ b/src/main/java/com/snowflake/snowpark_java/Session.java @@ -1,6 +1,7 @@ package com.snowflake.snowpark_java; import com.snowflake.snowpark.PublicPreview; +import com.snowflake.snowpark.SnowparkClientException; import com.snowflake.snowpark.internal.JavaUtils; import com.snowflake.snowpark_java.types.InternalUtils; import com.snowflake.snowpark_java.types.StructType; @@ -297,6 +298,49 @@ public void unsetQueryTag() { session.unsetQueryTag(); } + /** + * Updates the query tag that is a JSON encoded string for the current session. + * + *

Keep in mind that assigning a value via {@link Session#setQueryTag(String)} will remove any + * current query tag state. + * + *

Example 1: + * + *

{@code
+   * session.setQueryTag("{\"key1\":\"value1\"}");
+   * session.updateQueryTag("{\"key2\":\"value2\"}");
+   * System.out.println(session.getQueryTag().get());
+   * {"key1":"value1","key2":"value2"}
+   * }
+ * + *

Example 2: + * + *

{@code
+   * session.sql("ALTER SESSION SET QUERY_TAG = '{\"key1\":\"value1\"}'").collect();
+   * session.updateQueryTag("{\"key2\":\"value2\"}");
+   * System.out.println(session.getQueryTag().get());
+   * {"key1":"value1","key2":"value2"}
+   * }
+ * + *

Example 3: + * + *

{@code
+   * session.setQueryTag("");
+   * session.updateQueryTag("{\"key1\":\"value1\"}");
+   * System.out.println(session.getQueryTag().get());
+   * {"key1":"value1"}
+   * }
+ * + * @param queryTag A JSON encoded string that provides updates to the current query tag. + * @throws SnowparkClientException If the provided query tag or the query tag of the current + * session are not valid JSON strings; or if it could not serialize the query tag into a JSON + * string. + * @since 1.13.0 + */ + public void updateQueryTag(String queryTag) throws SnowparkClientException { + session.updateQueryTag(queryTag); + } + /** * Creates a new DataFrame via Generator function. * diff --git a/src/main/scala/com/snowflake/snowpark/Session.scala b/src/main/scala/com/snowflake/snowpark/Session.scala index 633c8e42..2b6c580d 100644 --- a/src/main/scala/com/snowflake/snowpark/Session.scala +++ b/src/main/scala/com/snowflake/snowpark/Session.scala @@ -1,5 +1,8 @@ package com.snowflake.snowpark +import com.fasterxml.jackson.databind.json.JsonMapper +import com.fasterxml.jackson.module.scala.{ClassTagExtensions, DefaultScalaModule} + import java.io.{File, FileInputStream, FileNotFoundException} import java.net.URI import java.sql.{Connection, Date, Time, Timestamp} @@ -26,6 +29,7 @@ import net.snowflake.client.jdbc.{SnowflakeConnectionV1, SnowflakeDriver, Snowfl import scala.concurrent.{ExecutionContext, Future} import scala.collection.JavaConverters._ import scala.reflect.runtime.universe.TypeTag +import scala.util.Try /** * @@ -61,6 +65,11 @@ import scala.reflect.runtime.universe.TypeTag * @since 0.1.0 */ class Session private (private[snowpark] val conn: ServerConnection) extends Logging { + private val jsonMapper = JsonMapper + .builder() + .addModule(DefaultScalaModule) + .build() :: ClassTagExtensions + private val STAGE_PREFIX = "@" // URI and file name with md5 private val classpathURIs = new ConcurrentHashMap[URI, Option[String]]().asScala @@ -321,6 +330,87 @@ class Session private (private[snowpark] val conn: ServerConnection) extends Log */ def getQueryTag(): Option[String] = this.conn.getQueryTag() + /** + * Updates the query tag that is a JSON encoded string for the current session. + * + * Keep in mind that assigning a value via [[setQueryTag]] will remove any current query tag + * state. + * + * Example 1: + * {{{ + * session.setQueryTag("""{"key1":"value1"}""") + * session.updateQueryTag("""{"key2":"value2"}""") + * print(session.getQueryTag().get) + * {"key1":"value1","key2":"value2"} + * }}} + * + * Example 2: + * {{{ + * session.sql("""ALTER SESSION SET QUERY_TAG = '{"key1":"value1"}'""").collect() + * session.updateQueryTag("""{"key2":"value2"}""") + * print(session.getQueryTag().get) + * {"key1":"value1","key2":"value2"} + * }}} + * + * Example 3: + * {{{ + * session.setQueryTag("") + * session.updateQueryTag("""{"key1":"value1"}""") + * print(session.getQueryTag().get) + * {"key1":"value1"} + * }}} + * + * @param queryTag A JSON encoded string that provides updates to the current query tag. + * @throws SnowparkClientException If the provided query tag or the query tag of the current + * session are not valid JSON strings; or if it could not + * serialize the query tag into a JSON string. + * @since 1.13.0 + */ + def updateQueryTag(queryTag: String): Unit = synchronized { + val newQueryTagMap = parseJsonString(queryTag) + if (newQueryTagMap.isEmpty) { + throw ErrorMessage.MISC_INVALID_INPUT_QUERY_TAG() + } + + var currentQueryTag = this.conn.getParameterValue("query_tag") + currentQueryTag = if (currentQueryTag.isEmpty) "{}" else currentQueryTag + + val currentQueryTagMap = parseJsonString(currentQueryTag) + if (currentQueryTagMap.isEmpty) { + throw ErrorMessage.MISC_INVALID_CURRENT_QUERY_TAG(currentQueryTag) + } + + val updatedQueryTagMap = currentQueryTagMap.get ++ newQueryTagMap.get + val updatedQueryTagStr = toJsonString(updatedQueryTagMap) + if (updatedQueryTagStr.isEmpty) { + throw ErrorMessage.MISC_FAILED_TO_SERIALIZE_QUERY_TAG() + } + + setQueryTag(updatedQueryTagStr.get) + } + + /** + * Attempts to parse a JSON-encoded string into a [[scala.collection.immutable.Map]]. + * + * @param jsonString The JSON-encoded string to parse. + * @return An `Option` containing the `Map` if the parsing of the JSON string was + * successful, or `None` otherwise. + */ + private def parseJsonString(jsonString: String): Option[Map[String, Any]] = { + Try(jsonMapper.readValue[Map[String, Any]](jsonString)).toOption + } + + /** + * Attempts to convert a [[scala.collection.immutable.Map]] into a JSON-encoded string. + * + * @param map The `Map` to convert. + * @return An `Option` containing the JSON-encoded string if the conversion was successful, + * or `None` otherwise. + */ + private def toJsonString(map: Map[String, Any]): Option[String] = { + Try(jsonMapper.writeValueAsString(map)).toOption + } + /* * Checks that the latest version of all jar dependencies is * uploaded to a stage and returns the staged URLs diff --git a/src/main/scala/com/snowflake/snowpark/internal/ErrorMessage.scala b/src/main/scala/com/snowflake/snowpark/internal/ErrorMessage.scala index e0df3d6b..ea14da1e 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/ErrorMessage.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/ErrorMessage.scala @@ -159,7 +159,10 @@ private[snowpark] object ErrorMessage { """Invalid input argument type, the input argument type of Explode function should be either Map or Array types. |The input argument type: %s |""".stripMargin, - "0425" -> "Unsupported Geometry output format: %s. Please set session parameter GEOMETRY_OUTPUT_FORMAT to GeoJSON.") + "0425" -> "Unsupported Geometry output format: %s. Please set session parameter GEOMETRY_OUTPUT_FORMAT to GeoJSON.", + "0426" -> "The given query tag must be a valid JSON string. Ensure it's correctly formatted as JSON.", + "0427" -> "The query tag of the current session must be a valid JSON string. Current query tag: %s", + "0428" -> "Failed to serialize the query tag into a JSON string.") // scalastyle:on /* @@ -409,6 +412,15 @@ private[snowpark] object ErrorMessage { def MISC_UNSUPPORTED_GEOMETRY_FORMAT(typeName: String): SnowparkClientException = createException("0425", typeName) + def MISC_INVALID_INPUT_QUERY_TAG(): SnowparkClientException = + createException("0426") + + def MISC_INVALID_CURRENT_QUERY_TAG(currentQueryTag: String): SnowparkClientException = + createException("0427", currentQueryTag) + + def MISC_FAILED_TO_SERIALIZE_QUERY_TAG(): SnowparkClientException = + createException("0428") + /** * Create Snowpark client Exception. * diff --git a/src/main/scala/com/snowflake/snowpark/internal/UDFClassPath.scala b/src/main/scala/com/snowflake/snowpark/internal/UDFClassPath.scala index aedbc8a6..31c8200f 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/UDFClassPath.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/UDFClassPath.scala @@ -3,6 +3,7 @@ package com.snowflake.snowpark.internal import com.fasterxml.jackson.annotation.JsonView import com.fasterxml.jackson.core.TreeNode import com.fasterxml.jackson.databind.JsonNode +import com.fasterxml.jackson.module.scala.DefaultScalaModule import java.io.File import java.net.{URI, URLClassLoader} @@ -24,6 +25,8 @@ object UDFClassPath extends Logging { val jacksonDatabindClass: Class[JsonNode] = classOf[com.fasterxml.jackson.databind.JsonNode] val jacksonCoreClass: Class[TreeNode] = classOf[com.fasterxml.jackson.core.TreeNode] val jacksonAnnotationClass: Class[JsonView] = classOf[com.fasterxml.jackson.annotation.JsonView] + val jacksonModuleScalaClass: Class[DefaultScalaModule] = + classOf[com.fasterxml.jackson.module.scala.DefaultScalaModule] val jacksonJarSeq = Seq( RequiredLibrary( getPathForClass(jacksonDatabindClass), @@ -33,7 +36,11 @@ object UDFClassPath extends Logging { RequiredLibrary( getPathForClass(jacksonAnnotationClass), "jackson-annotation", - jacksonAnnotationClass)) + jacksonAnnotationClass), + RequiredLibrary( + getPathForClass(jacksonModuleScalaClass), + "jackson-module-scala", + jacksonModuleScalaClass)) /* * Libraries required to compile java code generated by snowpark for user's lambda. diff --git a/src/test/java/com/snowflake/snowpark_test/JavaSessionNonStoredProcSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaSessionNonStoredProcSuite.java index e4062ebc..e651c948 100644 --- a/src/test/java/com/snowflake/snowpark_test/JavaSessionNonStoredProcSuite.java +++ b/src/test/java/com/snowflake/snowpark_test/JavaSessionNonStoredProcSuite.java @@ -2,6 +2,8 @@ // to make sure all API can be accessed from public package com.snowflake.snowpark_test; +import static org.junit.Assert.assertThrows; + import com.snowflake.snowpark.SnowparkClientException; import com.snowflake.snowpark.TestUtils; import com.snowflake.snowpark_java.*; @@ -56,6 +58,84 @@ public void tags() { assert !getSession().getQueryTag().isPresent(); } + @Test + public void updateQueryTagAddNewKeyValuePairs() { + String queryTag1 = "{\"key1\":\"value1\"}"; + getSession().setQueryTag(queryTag1); + + String queryTag2 = "{\"key2\":\"value2\",\"key3\":{\"key4\":0},\"key5\":{\"key6\":\"value6\"}}"; + getSession().updateQueryTag(queryTag2); + + String expected = + "{\"key1\":\"value1\",\"key2\":\"value2\",\"key3\":{\"key4\":0},\"key5\":{\"key6\":\"value6\"}}"; + assert getSession().getQueryTag().isPresent(); + assert getSession().getQueryTag().get().equals(expected); + } + + @Test + public void updateQueryTagUpdateKeyValuePairs() { + String queryTag1 = "{\"key1\":\"value1\",\"key2\":\"value2\",\"key3\":\"value3\"}"; + getSession().setQueryTag(queryTag1); + + String queryTag2 = "{\"key2\":\"newValue2\"}"; + getSession().updateQueryTag(queryTag2); + + String expected = "{\"key1\":\"value1\",\"key2\":\"newValue2\",\"key3\":\"value3\"}"; + assert getSession().getQueryTag().isPresent(); + assert getSession().getQueryTag().get().equals(expected); + } + + @Test + public void updateQueryTagEmptySessionQueryTag() { + getSession().setQueryTag(""); + + String queryTag = "{\"key1\":\"value1\"}"; + getSession().updateQueryTag(queryTag); + + assert getSession().getQueryTag().isPresent(); + assert getSession().getQueryTag().get().equals(queryTag); + } + + @Test + public void updateQueryTagInvalidInputQueryTag() { + String queryTag = "tag1"; + + SnowparkClientException exception = + assertThrows(SnowparkClientException.class, () -> getSession().updateQueryTag(queryTag)); + assert exception + .getMessage() + .equals( + "Error Code: 0426, Error message: " + + "The given query tag must be a valid JSON string. Ensure it's correctly formatted as JSON."); + } + + @Test + public void updateQueryTagInvalidSessionQueryTag() { + String queryTag1 = "tag1"; + getSession().setQueryTag(queryTag1); + + String queryTag2 = "{\"key1\":\"value1\"}"; + SnowparkClientException exception = + assertThrows(SnowparkClientException.class, () -> getSession().updateQueryTag(queryTag2)); + assert exception + .getMessage() + .equals( + "Error Code: 0427, Error message: " + + "The query tag of the current session must be a valid JSON string. Current query tag: tag1"); + } + + @Test + public void updateQueryTagFromAlterSession() { + getSession().sql("ALTER SESSION SET QUERY_TAG = '{\"key1\":\"value1\"}'").collect(); + + String queryTag2 = "{\"key2\":\"value2\"}"; + getSession().updateQueryTag(queryTag2); + + String expected = "{\"key1\":\"value1\",\"key2\":\"value2\"}"; + assert getSession().getQueryTag().isPresent(); + assert getSession().getQueryTag().get().equals(expected); + } + @Test public void dbAndSchema() { assert getSession() diff --git a/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala b/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala index 937b93e6..0ad6d802 100644 --- a/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala +++ b/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala @@ -859,4 +859,31 @@ class ErrorMessageSuite extends FunSuite { "Unsupported Geometry output format: KWT." + " Please set session parameter GEOMETRY_OUTPUT_FORMAT to GeoJSON.")) } + + test("MISC_INVALID_INPUT_QUERY_TAG") { + val ex = ErrorMessage.MISC_INVALID_INPUT_QUERY_TAG() + assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0426"))) + assert( + ex.message.startsWith( + "Error Code: 0426, Error message: " + + "The given query tag must be a valid JSON string. " + + "Ensure it's correctly formatted as JSON.")) + } + + test("MISC_INVALID_CURRENT_QUERY_TAG") { + val ex = ErrorMessage.MISC_INVALID_CURRENT_QUERY_TAG("myTag") + assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0427"))) + assert( + ex.message.startsWith( + "Error Code: 0427, Error message: The query tag of the current session " + + "must be a valid JSON string. Current query tag: myTag")) + } + + test("MISC_FAILED_TO_SERIALIZE_QUERY_TAG") { + val ex = ErrorMessage.MISC_FAILED_TO_SERIALIZE_QUERY_TAG() + assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0428"))) + assert( + ex.message.startsWith( + "Error Code: 0428, Error message: Failed to serialize the query tag into a JSON string.")) + } } diff --git a/src/test/scala/com/snowflake/snowpark_test/SessionSuite.scala b/src/test/scala/com/snowflake/snowpark_test/SessionSuite.scala index 9b82ba7b..3a4a2dbc 100644 --- a/src/test/scala/com/snowflake/snowpark_test/SessionSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/SessionSuite.scala @@ -239,6 +239,75 @@ class SessionSuite extends SNTestBase { assert(getParameterValue("query_tag", session) == queryTag2) } + test("updateQueryTag when adding new key-value pairs") { + val queryTag1 = """{"key1":"value1"}""" + session.setQueryTag(queryTag1) + + val queryTag2 = """{"key2":"value2","key3":{"key4":0},"key5":{"key6":"value6"}}""" + session.updateQueryTag(queryTag2) + + val expected = { + """{"key1":"value1","key2":"value2","key3":{"key4":0},"key5":{"key6":"value6"}}""" + } + val actual = getParameterValue("query_tag", session) + assert(actual == expected) + } + + test("updateQueryTag when updating an existing key-value pair") { + val queryTag1 = """{"key1":"value1","key2":"value2","key3":"value3"}""" + session.setQueryTag(queryTag1) + + val queryTag2 = """{"key2":"newValue2"}""" + session.updateQueryTag(queryTag2) + + val expected = """{"key1":"value1","key2":"newValue2","key3":"value3"}""" + val actual = getParameterValue("query_tag", session) + assert(actual == expected) + } + + test("updateQueryTag when the query tag of the current session is empty") { + session.setQueryTag("") + + val queryTag = """{"key1":"value1"}""" + session.updateQueryTag(queryTag) + + val actual = getParameterValue("query_tag", session) + assert(actual == queryTag) + } + + test("updateQueryTag when the given query tag is not a valid JSON") { + val queryTag = "tag1" + val exception = intercept[SnowparkClientException](session.updateQueryTag(queryTag)) + assert( + exception.message.startsWith( + "Error Code: 0426, Error message: The given query tag must be a valid JSON string. " + + "Ensure it's correctly formatted as JSON.")) + } + + test("updateQueryTag when the query tag of the current session is not a valid JSON") { + val queryTag1 = "tag1" + session.setQueryTag(queryTag1) + + val queryTag2 = """{"key1":"value1"}""" + val exception = intercept[SnowparkClientException](session.updateQueryTag(queryTag2)) + assert( + exception.message.startsWith( + "Error Code: 0427, Error message: The query tag of the current session must be a valid " + + "JSON string. Current query tag: tag1")) + } + + test("updateQueryTag when the query tag of the current session is set with an ALTER SESSION") { + val queryTag1 = """{"key1":"value1"}""" + session.sql(s"ALTER SESSION SET QUERY_TAG = '$queryTag1'").collect() + + val queryTag2 = """{"key2":"value2"}""" + session.updateQueryTag(queryTag2) + + val expected = """{"key1":"value1","key2":"value2"}""" + val actual = getParameterValue("query_tag", session) + assert(actual == expected) + } + test("Multiple queries test for query tags") { val queryTag = randomName() session.setQueryTag(queryTag)