Skip to content

Commit

Permalink
SIT-1382: Add support for `com.snowflake.snowpark.Session.updateQuery…
Browse files Browse the repository at this point in the history
…Tag` function (#115)
  • Loading branch information
sfc-gh-fgonzalezmendez authored Jul 1, 2024
1 parent 76cfb94 commit 11031af
Show file tree
Hide file tree
Showing 9 changed files with 344 additions and 2 deletions.
6 changes: 6 additions & 0 deletions fips-pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
<scalaPluginVersion>4.3.0</scalaPluginVersion>
<jackson.version>2.13.2</jackson.version>
<jackson.databind.version>2.13.4.2</jackson.databind.version>
<jackson.module.scala.version>2.13.5</jackson.module.scala.version>
</properties>

<dependencyManagement>
Expand Down Expand Up @@ -144,6 +145,11 @@
<artifactId>jackson-annotations</artifactId>
<version>${jackson.version}</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.module</groupId>
<artifactId>jackson-module-scala_2.12</artifactId>
<version>${jackson.module.scala.version}</version>
</dependency>

<!-- Test -->
<dependency>
Expand Down
7 changes: 7 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@
<scalaPluginVersion>4.3.0</scalaPluginVersion>
<jackson.version>2.13.2</jackson.version>
<jackson.databind.version>2.13.4.2</jackson.databind.version>
<jackson.module.scala.version>2.13.5</jackson.module.scala.version>

</properties>
<dependencyManagement>
<dependencies>
Expand Down Expand Up @@ -131,6 +133,11 @@
<artifactId>jackson-annotations</artifactId>
<version>${jackson.version}</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.module</groupId>
<artifactId>jackson-module-scala_2.12</artifactId>
<version>${jackson.module.scala.version}</version>
</dependency>

<!-- Test -->
<dependency>
Expand Down
44 changes: 44 additions & 0 deletions src/main/java/com/snowflake/snowpark_java/Session.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.snowflake.snowpark_java;

import com.snowflake.snowpark.PublicPreview;
import com.snowflake.snowpark.SnowparkClientException;
import com.snowflake.snowpark.internal.JavaUtils;
import com.snowflake.snowpark_java.types.InternalUtils;
import com.snowflake.snowpark_java.types.StructType;
Expand Down Expand Up @@ -297,6 +298,49 @@ public void unsetQueryTag() {
session.unsetQueryTag();
}

/**
* Updates the query tag that is a JSON encoded string for the current session.
*
* <p>Keep in mind that assigning a value via {@link Session#setQueryTag(String)} will remove any
* current query tag state.
*
* <p>Example 1:
*
* <pre>{@code
* session.setQueryTag("{\"key1\":\"value1\"}");
* session.updateQueryTag("{\"key2\":\"value2\"}");
* System.out.println(session.getQueryTag().get());
* {"key1":"value1","key2":"value2"}
* }</pre>
*
* <p>Example 2:
*
* <pre>{@code
* session.sql("ALTER SESSION SET QUERY_TAG = '{\"key1\":\"value1\"}'").collect();
* session.updateQueryTag("{\"key2\":\"value2\"}");
* System.out.println(session.getQueryTag().get());
* {"key1":"value1","key2":"value2"}
* }</pre>
*
* <p>Example 3:
*
* <pre>{@code
* session.setQueryTag("");
* session.updateQueryTag("{\"key1\":\"value1\"}");
* System.out.println(session.getQueryTag().get());
* {"key1":"value1"}
* }</pre>
*
* @param queryTag A JSON encoded string that provides updates to the current query tag.
* @throws SnowparkClientException If the provided query tag or the query tag of the current
* session are not valid JSON strings; or if it could not serialize the query tag into a JSON
* string.
* @since 1.13.0
*/
public void updateQueryTag(String queryTag) throws SnowparkClientException {
session.updateQueryTag(queryTag);
}

/**
* Creates a new DataFrame via Generator function.
*
Expand Down
90 changes: 90 additions & 0 deletions src/main/scala/com/snowflake/snowpark/Session.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
package com.snowflake.snowpark

import com.fasterxml.jackson.databind.json.JsonMapper
import com.fasterxml.jackson.module.scala.{ClassTagExtensions, DefaultScalaModule}

import java.io.{File, FileInputStream, FileNotFoundException}
import java.net.URI
import java.sql.{Connection, Date, Time, Timestamp}
Expand All @@ -26,6 +29,7 @@ import net.snowflake.client.jdbc.{SnowflakeConnectionV1, SnowflakeDriver, Snowfl
import scala.concurrent.{ExecutionContext, Future}
import scala.collection.JavaConverters._
import scala.reflect.runtime.universe.TypeTag
import scala.util.Try

/**
*
Expand Down Expand Up @@ -61,6 +65,11 @@ import scala.reflect.runtime.universe.TypeTag
* @since 0.1.0
*/
class Session private (private[snowpark] val conn: ServerConnection) extends Logging {
private val jsonMapper = JsonMapper
.builder()
.addModule(DefaultScalaModule)
.build() :: ClassTagExtensions

private val STAGE_PREFIX = "@"
// URI and file name with md5
private val classpathURIs = new ConcurrentHashMap[URI, Option[String]]().asScala
Expand Down Expand Up @@ -321,6 +330,87 @@ class Session private (private[snowpark] val conn: ServerConnection) extends Log
*/
def getQueryTag(): Option[String] = this.conn.getQueryTag()

/**
* Updates the query tag that is a JSON encoded string for the current session.
*
* Keep in mind that assigning a value via [[setQueryTag]] will remove any current query tag
* state.
*
* Example 1:
* {{{
* session.setQueryTag("""{"key1":"value1"}""")
* session.updateQueryTag("""{"key2":"value2"}""")
* print(session.getQueryTag().get)
* {"key1":"value1","key2":"value2"}
* }}}
*
* Example 2:
* {{{
* session.sql("""ALTER SESSION SET QUERY_TAG = '{"key1":"value1"}'""").collect()
* session.updateQueryTag("""{"key2":"value2"}""")
* print(session.getQueryTag().get)
* {"key1":"value1","key2":"value2"}
* }}}
*
* Example 3:
* {{{
* session.setQueryTag("")
* session.updateQueryTag("""{"key1":"value1"}""")
* print(session.getQueryTag().get)
* {"key1":"value1"}
* }}}
*
* @param queryTag A JSON encoded string that provides updates to the current query tag.
* @throws SnowparkClientException If the provided query tag or the query tag of the current
* session are not valid JSON strings; or if it could not
* serialize the query tag into a JSON string.
* @since 1.13.0
*/
def updateQueryTag(queryTag: String): Unit = synchronized {
val newQueryTagMap = parseJsonString(queryTag)
if (newQueryTagMap.isEmpty) {
throw ErrorMessage.MISC_INVALID_INPUT_QUERY_TAG()
}

var currentQueryTag = this.conn.getParameterValue("query_tag")
currentQueryTag = if (currentQueryTag.isEmpty) "{}" else currentQueryTag

val currentQueryTagMap = parseJsonString(currentQueryTag)
if (currentQueryTagMap.isEmpty) {
throw ErrorMessage.MISC_INVALID_CURRENT_QUERY_TAG(currentQueryTag)
}

val updatedQueryTagMap = currentQueryTagMap.get ++ newQueryTagMap.get
val updatedQueryTagStr = toJsonString(updatedQueryTagMap)
if (updatedQueryTagStr.isEmpty) {
throw ErrorMessage.MISC_FAILED_TO_SERIALIZE_QUERY_TAG()
}

setQueryTag(updatedQueryTagStr.get)
}

/**
* Attempts to parse a JSON-encoded string into a [[scala.collection.immutable.Map]].
*
* @param jsonString The JSON-encoded string to parse.
* @return An `Option` containing the `Map` if the parsing of the JSON string was
* successful, or `None` otherwise.
*/
private def parseJsonString(jsonString: String): Option[Map[String, Any]] = {
Try(jsonMapper.readValue[Map[String, Any]](jsonString)).toOption
}

/**
* Attempts to convert a [[scala.collection.immutable.Map]] into a JSON-encoded string.
*
* @param map The `Map` to convert.
* @return An `Option` containing the JSON-encoded string if the conversion was successful,
* or `None` otherwise.
*/
private def toJsonString(map: Map[String, Any]): Option[String] = {
Try(jsonMapper.writeValueAsString(map)).toOption
}

/*
* Checks that the latest version of all jar dependencies is
* uploaded to a stage and returns the staged URLs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,10 @@ private[snowpark] object ErrorMessage {
"""Invalid input argument type, the input argument type of Explode function should be either Map or Array types.
|The input argument type: %s
|""".stripMargin,
"0425" -> "Unsupported Geometry output format: %s. Please set session parameter GEOMETRY_OUTPUT_FORMAT to GeoJSON.")
"0425" -> "Unsupported Geometry output format: %s. Please set session parameter GEOMETRY_OUTPUT_FORMAT to GeoJSON.",
"0426" -> "The given query tag must be a valid JSON string. Ensure it's correctly formatted as JSON.",
"0427" -> "The query tag of the current session must be a valid JSON string. Current query tag: %s",
"0428" -> "Failed to serialize the query tag into a JSON string.")
// scalastyle:on

/*
Expand Down Expand Up @@ -409,6 +412,15 @@ private[snowpark] object ErrorMessage {
def MISC_UNSUPPORTED_GEOMETRY_FORMAT(typeName: String): SnowparkClientException =
createException("0425", typeName)

def MISC_INVALID_INPUT_QUERY_TAG(): SnowparkClientException =
createException("0426")

def MISC_INVALID_CURRENT_QUERY_TAG(currentQueryTag: String): SnowparkClientException =
createException("0427", currentQueryTag)

def MISC_FAILED_TO_SERIALIZE_QUERY_TAG(): SnowparkClientException =
createException("0428")

/**
* Create Snowpark client Exception.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package com.snowflake.snowpark.internal
import com.fasterxml.jackson.annotation.JsonView
import com.fasterxml.jackson.core.TreeNode
import com.fasterxml.jackson.databind.JsonNode
import com.fasterxml.jackson.module.scala.DefaultScalaModule

import java.io.File
import java.net.{URI, URLClassLoader}
Expand All @@ -24,6 +25,8 @@ object UDFClassPath extends Logging {
val jacksonDatabindClass: Class[JsonNode] = classOf[com.fasterxml.jackson.databind.JsonNode]
val jacksonCoreClass: Class[TreeNode] = classOf[com.fasterxml.jackson.core.TreeNode]
val jacksonAnnotationClass: Class[JsonView] = classOf[com.fasterxml.jackson.annotation.JsonView]
val jacksonModuleScalaClass: Class[DefaultScalaModule] =
classOf[com.fasterxml.jackson.module.scala.DefaultScalaModule]
val jacksonJarSeq = Seq(
RequiredLibrary(
getPathForClass(jacksonDatabindClass),
Expand All @@ -33,7 +36,11 @@ object UDFClassPath extends Logging {
RequiredLibrary(
getPathForClass(jacksonAnnotationClass),
"jackson-annotation",
jacksonAnnotationClass))
jacksonAnnotationClass),
RequiredLibrary(
getPathForClass(jacksonModuleScalaClass),
"jackson-module-scala",
jacksonModuleScalaClass))

/*
* Libraries required to compile java code generated by snowpark for user's lambda.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
// to make sure all API can be accessed from public
package com.snowflake.snowpark_test;

import static org.junit.Assert.assertThrows;

import com.snowflake.snowpark.SnowparkClientException;
import com.snowflake.snowpark.TestUtils;
import com.snowflake.snowpark_java.*;
Expand Down Expand Up @@ -56,6 +58,84 @@ public void tags() {
assert !getSession().getQueryTag().isPresent();
}

@Test
public void updateQueryTagAddNewKeyValuePairs() {
String queryTag1 = "{\"key1\":\"value1\"}";
getSession().setQueryTag(queryTag1);

String queryTag2 = "{\"key2\":\"value2\",\"key3\":{\"key4\":0},\"key5\":{\"key6\":\"value6\"}}";
getSession().updateQueryTag(queryTag2);

String expected =
"{\"key1\":\"value1\",\"key2\":\"value2\",\"key3\":{\"key4\":0},\"key5\":{\"key6\":\"value6\"}}";
assert getSession().getQueryTag().isPresent();
assert getSession().getQueryTag().get().equals(expected);
}

@Test
public void updateQueryTagUpdateKeyValuePairs() {
String queryTag1 = "{\"key1\":\"value1\",\"key2\":\"value2\",\"key3\":\"value3\"}";
getSession().setQueryTag(queryTag1);

String queryTag2 = "{\"key2\":\"newValue2\"}";
getSession().updateQueryTag(queryTag2);

String expected = "{\"key1\":\"value1\",\"key2\":\"newValue2\",\"key3\":\"value3\"}";
assert getSession().getQueryTag().isPresent();
assert getSession().getQueryTag().get().equals(expected);
}

@Test
public void updateQueryTagEmptySessionQueryTag() {
getSession().setQueryTag("");

String queryTag = "{\"key1\":\"value1\"}";
getSession().updateQueryTag(queryTag);

assert getSession().getQueryTag().isPresent();
assert getSession().getQueryTag().get().equals(queryTag);
}

@Test
public void updateQueryTagInvalidInputQueryTag() {
String queryTag = "tag1";

SnowparkClientException exception =
assertThrows(SnowparkClientException.class, () -> getSession().updateQueryTag(queryTag));
assert exception
.getMessage()
.equals(
"Error Code: 0426, Error message: "
+ "The given query tag must be a valid JSON string. Ensure it's correctly formatted as JSON.");
}

@Test
public void updateQueryTagInvalidSessionQueryTag() {
String queryTag1 = "tag1";
getSession().setQueryTag(queryTag1);

String queryTag2 = "{\"key1\":\"value1\"}";
SnowparkClientException exception =
assertThrows(SnowparkClientException.class, () -> getSession().updateQueryTag(queryTag2));
assert exception
.getMessage()
.equals(
"Error Code: 0427, Error message: "
+ "The query tag of the current session must be a valid JSON string. Current query tag: tag1");
}

@Test
public void updateQueryTagFromAlterSession() {
getSession().sql("ALTER SESSION SET QUERY_TAG = '{\"key1\":\"value1\"}'").collect();

String queryTag2 = "{\"key2\":\"value2\"}";
getSession().updateQueryTag(queryTag2);

String expected = "{\"key1\":\"value1\",\"key2\":\"value2\"}";
assert getSession().getQueryTag().isPresent();
assert getSession().getQueryTag().get().equals(expected);
}

@Test
public void dbAndSchema() {
assert getSession()
Expand Down
27 changes: 27 additions & 0 deletions src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -859,4 +859,31 @@ class ErrorMessageSuite extends FunSuite {
"Unsupported Geometry output format: KWT." +
" Please set session parameter GEOMETRY_OUTPUT_FORMAT to GeoJSON."))
}

test("MISC_INVALID_INPUT_QUERY_TAG") {
val ex = ErrorMessage.MISC_INVALID_INPUT_QUERY_TAG()
assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0426")))
assert(
ex.message.startsWith(
"Error Code: 0426, Error message: " +
"The given query tag must be a valid JSON string. " +
"Ensure it's correctly formatted as JSON."))
}

test("MISC_INVALID_CURRENT_QUERY_TAG") {
val ex = ErrorMessage.MISC_INVALID_CURRENT_QUERY_TAG("myTag")
assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0427")))
assert(
ex.message.startsWith(
"Error Code: 0427, Error message: The query tag of the current session " +
"must be a valid JSON string. Current query tag: myTag"))
}

test("MISC_FAILED_TO_SERIALIZE_QUERY_TAG") {
val ex = ErrorMessage.MISC_FAILED_TO_SERIALIZE_QUERY_TAG()
assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0428")))
assert(
ex.message.startsWith(
"Error Code: 0428, Error message: Failed to serialize the query tag into a JSON string."))
}
}
Loading

0 comments on commit 11031af

Please sign in to comment.