Skip to content

Commit

Permalink
Merge branch 'main' into snow-802269-func6-gm
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-gmahadevan authored Aug 26, 2024
2 parents 66c2975 + 3ab8a52 commit 410a4e9
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 40 deletions.
62 changes: 22 additions & 40 deletions src/main/scala/com/snowflake/snowpark/internal/OpenTelemetry.scala
Original file line number Diff line number Diff line change
Expand Up @@ -60,51 +60,33 @@ object OpenTelemetry extends Logging {
execName: String,
execHandler: String,
execFilePath: String)(func: => T): T = {
try {
spanInfo.withValue[T](spanInfo.value match {
// empty info means this is the entry of the recursion
case None =>
val stacks = Thread.currentThread().getStackTrace
val (fileName, lineNumber) = findLineNumber(stacks)
Some(
UdfInfo(
className,
funcName,
fileName,
lineNumber,
execName,
execHandler,
execFilePath))
// if value is not empty, this function call should be recursion.
// do not issue new SpanInfo, use the info inherited from previous.
case other => other
}) {
val result: T = func
OpenTelemetry.emit(spanInfo.value.get)
result
}
} catch {
case error: Throwable =>
OpenTelemetry.reportError(className, funcName, error)
throw error
}
val stacks = Thread.currentThread().getStackTrace
val (fileName, lineNumber) = findLineNumber(stacks)
val newSpan =
UdfInfo(className, funcName, fileName, lineNumber, execName, execHandler, execFilePath)
emitSpan(newSpan, className, funcName, func)
}
// wrapper of all action functions
def action[T](className: String, funcName: String, methodChain: String)(func: => T): T = {
val stacks = Thread.currentThread().getStackTrace
val (fileName, lineNumber) = findLineNumber(stacks)
val newInfo =
ActionInfo(className, funcName, fileName, lineNumber, s"$methodChain.$funcName")
emitSpan(newInfo, className, funcName, func)
}

private def emitSpan[T](span: SpanInfo, className: String, funcName: String, thunk: => T): T = {
try {
spanInfo.withValue[T](spanInfo.value match {
// empty info means this is the entry of the recursion
spanInfo.value match {
case None =>
val stacks = Thread.currentThread().getStackTrace
val (fileName, lineNumber) = findLineNumber(stacks)
Some(ActionInfo(className, funcName, fileName, lineNumber, s"$methodChain.$funcName"))
// if value is not empty, this function call should be recursion.
// do not issue new SpanInfo, use the info inherited from previous.
case other => other
}) {
val result: T = func
OpenTelemetry.emit(spanInfo.value.get)
result
spanInfo.withValue(Some(span)) {
val result: T = thunk
// only emit one time, in the top level action
OpenTelemetry.emit(spanInfo.value.get)
result
}
case _ =>
thunk
}
} catch {
case error: Throwable =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,12 @@ class OpenTelemetrySuite extends OpenTelemetryEnabled {
checkSpanError("snow.snowpark.ClassA1", "functionB1", error)
}

test("only emit span once in the nested actions") {
session.sql("select 1").count()
val l = testSpanExporter.getFinishedSpanItems
assert(l.size() == 1)
}

override def beforeAll: Unit = {
super.beforeAll
createStage(stageName1)
Expand Down

0 comments on commit 410a4e9

Please sign in to comment.