Skip to content

Commit

Permalink
Merge branch 'main' into snow-802269-func5-gm
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-gmahadevan authored Aug 27, 2024
2 parents 2b2c576 + 3ab8a52 commit 1851988
Show file tree
Hide file tree
Showing 8 changed files with 152 additions and 71 deletions.
6 changes: 0 additions & 6 deletions fips-pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
<scalaPluginVersion>4.3.0</scalaPluginVersion>
<jackson.version>2.13.2</jackson.version>
<jackson.databind.version>2.13.4.2</jackson.databind.version>
<jackson.module.scala.version>2.13.5</jackson.module.scala.version>
</properties>

<dependencyManagement>
Expand Down Expand Up @@ -145,11 +144,6 @@
<artifactId>jackson-annotations</artifactId>
<version>${jackson.version}</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.module</groupId>
<artifactId>jackson-module-scala_2.12</artifactId>
<version>${jackson.module.scala.version}</version>
</dependency>

<!-- Test -->
<dependency>
Expand Down
6 changes: 0 additions & 6 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
<scalaPluginVersion>4.3.0</scalaPluginVersion>
<jackson.version>2.13.2</jackson.version>
<jackson.databind.version>2.13.4.2</jackson.databind.version>
<jackson.module.scala.version>2.13.5</jackson.module.scala.version>

</properties>
<dependencyManagement>
Expand Down Expand Up @@ -133,11 +132,6 @@
<artifactId>jackson-annotations</artifactId>
<version>${jackson.version}</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.module</groupId>
<artifactId>jackson-module-scala_2.12</artifactId>
<version>${jackson.module.scala.version}</version>
</dependency>

<!-- Test -->
<dependency>
Expand Down
13 changes: 2 additions & 11 deletions src/main/scala/com/snowflake/snowpark/Session.scala
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
package com.snowflake.snowpark

import com.fasterxml.jackson.databind.json.JsonMapper
import com.fasterxml.jackson.module.scala.{ClassTagExtensions, DefaultScalaModule}

import java.io.{File, FileInputStream, FileNotFoundException}
import java.net.URI
import java.sql.{Connection, Date, Time, Timestamp}
Expand All @@ -29,7 +26,6 @@ import net.snowflake.client.jdbc.{SnowflakeConnectionV1, SnowflakeDriver, Snowfl
import scala.concurrent.{ExecutionContext, Future}
import scala.collection.JavaConverters._
import scala.reflect.runtime.universe.TypeTag
import scala.util.Try

/**
*
Expand Down Expand Up @@ -65,11 +61,6 @@ import scala.util.Try
* @since 0.1.0
*/
class Session private (private[snowpark] val conn: ServerConnection) extends Logging {
private val jsonMapper = JsonMapper
.builder()
.addModule(DefaultScalaModule)
.build() :: ClassTagExtensions

private val STAGE_PREFIX = "@"
// URI and file name with md5
private val classpathURIs = new ConcurrentHashMap[URI, Option[String]]().asScala
Expand Down Expand Up @@ -397,7 +388,7 @@ class Session private (private[snowpark] val conn: ServerConnection) extends Log
* successful, or `None` otherwise.
*/
private def parseJsonString(jsonString: String): Option[Map[String, Any]] = {
Try(jsonMapper.readValue[Map[String, Any]](jsonString)).toOption
Utils.jsonToMap(jsonString)
}

/**
Expand All @@ -408,7 +399,7 @@ class Session private (private[snowpark] val conn: ServerConnection) extends Log
* or `None` otherwise.
*/
private def toJsonString(map: Map[String, Any]): Option[String] = {
Try(jsonMapper.writeValueAsString(map)).toOption
Utils.mapToJson(map)
}

/*
Expand Down
62 changes: 22 additions & 40 deletions src/main/scala/com/snowflake/snowpark/internal/OpenTelemetry.scala
Original file line number Diff line number Diff line change
Expand Up @@ -60,51 +60,33 @@ object OpenTelemetry extends Logging {
execName: String,
execHandler: String,
execFilePath: String)(func: => T): T = {
try {
spanInfo.withValue[T](spanInfo.value match {
// empty info means this is the entry of the recursion
case None =>
val stacks = Thread.currentThread().getStackTrace
val (fileName, lineNumber) = findLineNumber(stacks)
Some(
UdfInfo(
className,
funcName,
fileName,
lineNumber,
execName,
execHandler,
execFilePath))
// if value is not empty, this function call should be recursion.
// do not issue new SpanInfo, use the info inherited from previous.
case other => other
}) {
val result: T = func
OpenTelemetry.emit(spanInfo.value.get)
result
}
} catch {
case error: Throwable =>
OpenTelemetry.reportError(className, funcName, error)
throw error
}
val stacks = Thread.currentThread().getStackTrace
val (fileName, lineNumber) = findLineNumber(stacks)
val newSpan =
UdfInfo(className, funcName, fileName, lineNumber, execName, execHandler, execFilePath)
emitSpan(newSpan, className, funcName, func)
}
// wrapper of all action functions
def action[T](className: String, funcName: String, methodChain: String)(func: => T): T = {
val stacks = Thread.currentThread().getStackTrace
val (fileName, lineNumber) = findLineNumber(stacks)
val newInfo =
ActionInfo(className, funcName, fileName, lineNumber, s"$methodChain.$funcName")
emitSpan(newInfo, className, funcName, func)
}

private def emitSpan[T](span: SpanInfo, className: String, funcName: String, thunk: => T): T = {
try {
spanInfo.withValue[T](spanInfo.value match {
// empty info means this is the entry of the recursion
spanInfo.value match {
case None =>
val stacks = Thread.currentThread().getStackTrace
val (fileName, lineNumber) = findLineNumber(stacks)
Some(ActionInfo(className, funcName, fileName, lineNumber, s"$methodChain.$funcName"))
// if value is not empty, this function call should be recursion.
// do not issue new SpanInfo, use the info inherited from previous.
case other => other
}) {
val result: T = func
OpenTelemetry.emit(spanInfo.value.get)
result
spanInfo.withValue(Some(span)) {
val result: T = thunk
// only emit one time, in the top level action
OpenTelemetry.emit(spanInfo.value.get)
result
}
case _ =>
thunk
}
} catch {
case error: Throwable =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package com.snowflake.snowpark.internal
import com.fasterxml.jackson.annotation.JsonView
import com.fasterxml.jackson.core.TreeNode
import com.fasterxml.jackson.databind.JsonNode
import com.fasterxml.jackson.module.scala.DefaultScalaModule

import java.io.File
import java.net.{URI, URLClassLoader}
Expand All @@ -25,8 +24,6 @@ object UDFClassPath extends Logging {
val jacksonDatabindClass: Class[JsonNode] = classOf[com.fasterxml.jackson.databind.JsonNode]
val jacksonCoreClass: Class[TreeNode] = classOf[com.fasterxml.jackson.core.TreeNode]
val jacksonAnnotationClass: Class[JsonView] = classOf[com.fasterxml.jackson.annotation.JsonView]
val jacksonModuleScalaClass: Class[DefaultScalaModule] =
classOf[com.fasterxml.jackson.module.scala.DefaultScalaModule]
val jacksonJarSeq = Seq(
RequiredLibrary(
getPathForClass(jacksonDatabindClass),
Expand All @@ -36,11 +33,7 @@ object UDFClassPath extends Logging {
RequiredLibrary(
getPathForClass(jacksonAnnotationClass),
"jackson-annotation",
jacksonAnnotationClass),
RequiredLibrary(
getPathForClass(jacksonModuleScalaClass),
"jackson-module-scala",
jacksonModuleScalaClass))
jacksonAnnotationClass))

/*
* Libraries required to compile java code generated by snowpark for user's lambda.
Expand Down
68 changes: 68 additions & 0 deletions src/main/scala/com/snowflake/snowpark/internal/Utils.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package com.snowflake.snowpark.internal

import com.fasterxml.jackson.databind.{JsonNode, ObjectMapper}
import com.fasterxml.jackson.databind.node.JsonNodeType
import com.snowflake.snowpark.Column
import com.snowflake.snowpark.internal.analyzer.{
Attribute,
Expand All @@ -15,6 +17,7 @@ import java.util.Locale
import com.snowflake.snowpark.udtf.UDTF
import net.snowflake.client.jdbc.SnowflakeSQLException

import scala.collection.JavaConverters.{asScalaIteratorConverter, mapAsScalaMapConverter}
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.util.Random
Expand Down Expand Up @@ -447,4 +450,69 @@ object Utils extends Logging {
case _ => throw ErrorMessage.DF_JOIN_WITH_WRONG_ARGUMENT()
}
}

private val objectMapper = new ObjectMapper()

private[snowpark] def jsonToMap(jsonString: String): Option[Map[String, Any]] = {
try {
val node = objectMapper.readTree(jsonString)
assert(node.getNodeType == JsonNodeType.OBJECT)
Some(jsonToScala(node).asInstanceOf[Map[String, Any]])
} catch {
case ex: Exception =>
logError(ex.getMessage)
None
}
}

private def jsonToScala(node: JsonNode): Any = {
node.getNodeType match {
case JsonNodeType.STRING => node.asText()
case JsonNodeType.NULL => null
case JsonNodeType.OBJECT =>
node
.fields()
.asScala
.map(entry => {
entry.getKey -> jsonToScala(entry.getValue)
})
.toMap
case JsonNodeType.ARRAY =>
node.elements().asScala.map(entry => jsonToScala(entry)).toSeq
case JsonNodeType.BOOLEAN => node.asBoolean()
case JsonNodeType.NUMBER => node.numberValue()
case other =>
throw new UnsupportedOperationException(s"Unsupported Type: ${other.name()}")
}
}

private[snowpark] def mapToJson(map: Map[String, Any]): Option[String] = {
try {
Some(scalaToJson(map))
} catch {
case ex: Exception =>
logError(ex.getMessage)
None
}
}

private def scalaToJson(input: Any): String =
input match {
case null => "null"
case str: String => s""""$str""""
case _: Int | _: Short | _: Long | _: Byte | _: Double | _: Float | _: Boolean =>
input.toString
case map: Map[String, _] =>
map
.map {
case (key, value) => s"${scalaToJson(key)}:${scalaToJson(value)}"
}
.mkString("{", ",", "}")
case seq: Seq[_] => seq.map(scalaToJson).mkString("[", ",", "]")
case arr: Array[_] => scalaToJson(arr.toSeq)
case list: java.util.List[_] => scalaToJson(list.toArray)
case map: java.util.Map[String, _] => scalaToJson(map.asScala.toMap)
case _ =>
throw new UnsupportedOperationException(s"Unsupported Type: ${input.getClass.getName}")
}
}
53 changes: 53 additions & 0 deletions src/test/scala/com/snowflake/snowpark/UtilsSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import java.lang.{
}
import net.snowflake.client.jdbc.SnowflakeSQLException

import java.util
import scala.collection.mutable.ArrayBuffer

class UtilsSuite extends SNTestBase {
Expand Down Expand Up @@ -672,6 +673,58 @@ class UtilsSuite extends SNTestBase {
assert(Utils.quoteForOption("FALSE").equals("FALSE"))
assert(Utils.quoteForOption("abc").equals("'abc'"))
}

test("Scala and Json format transformation") {
val javaHashMap = new util.HashMap[String, String]() {
{
put("one", "1")
put("two", "2")
put("three", "3")
}
}
val map = Map(
"nullKey" -> null,
"integerKey" -> 42,
"shortKey" -> 123.toShort,
"longKey" -> 1234567890L,
"byteKey" -> 123.toByte,
"doubleKey" -> 3.1415926,
"floatKey" -> 3.14F,
"boolKey" -> false,
"javaListKey" -> new util.ArrayList[String](util.Arrays.asList("a", "b")),
"javaMapKey" -> javaHashMap,
"seqKey" -> Seq(1, 2, 3),
"arrayKey" -> Array(1, 2, 3),
"seqOfStringKey" -> Seq("1", "2", "3"),
"stringKey" -> "stringValue",
"nestedMap" -> Map("insideKey" -> "stringValue", "insideList" -> Seq(1, 2, 3)),
"nestedList" -> Seq(1, Map("nestedKey" -> "nestedValue"), Array(1, 2, 3)))
val jsonString = Utils.mapToJson(map)
val expected_string = "{" +
"\"floatKey\":3.14," +
"\"javaMapKey\":{" +
"\"one\":\"1\"," +
"\"two\":\"2\"," +
"\"three\":\"3\"}," +
"\"integerKey\":42," +
"\"nullKey\":null," +
"\"longKey\":1234567890," +
"\"byteKey\":123," +
"\"seqKey\":[1,2,3]," +
"\"nestedMap\":{\"insideKey\":\"stringValue\",\"insideList\":[1,2,3]}," +
"\"stringKey\":\"stringValue\"," +
"\"doubleKey\":3.1415926," +
"\"seqOfStringKey\":[\"1\",\"2\",\"3\"]," +
"\"nestedList\":[1,{\"nestedKey\":\"nestedValue\"},[1,2,3]]," +
"\"javaListKey\":[\"a\",\"b\"]," +
"\"arrayKey\":[1,2,3]," +
"\"boolKey\":false," +
"\"shortKey\":123}"
val readMap = Utils.jsonToMap(jsonString.getOrElse(""))
val transformedString = Utils.mapToJson(readMap.getOrElse(Map()))
assert(jsonString.getOrElse("").equals(expected_string))
assert(jsonString.equals(transformedString))
}
}

object LoggingTester extends Logging {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,12 @@ class OpenTelemetrySuite extends OpenTelemetryEnabled {
checkSpanError("snow.snowpark.ClassA1", "functionB1", error)
}

test("only emit span once in the nested actions") {
session.sql("select 1").count()
val l = testSpanExporter.getFinishedSpanItems
assert(l.size() == 1)
}

override def beforeAll: Unit = {
super.beforeAll
createStage(stageName1)
Expand Down

0 comments on commit 1851988

Please sign in to comment.