From 3ab8a52cea024ba1ebfa5f6468ecd6d5e83d6d80 Mon Sep 17 00:00:00 2001 From: Bing Li <63471091+sfc-gh-bli@users.noreply.github.com> Date: Mon, 26 Aug 2024 13:37:54 -0700 Subject: [PATCH] SNOW-1638551 Only Emit Span Once in the Nested Actions (#151) only emit span once in the nested actions --- .../snowpark/internal/OpenTelemetry.scala | 62 +++++++------------ .../snowpark_test/OpenTelemetrySuite.scala | 6 ++ 2 files changed, 28 insertions(+), 40 deletions(-) diff --git a/src/main/scala/com/snowflake/snowpark/internal/OpenTelemetry.scala b/src/main/scala/com/snowflake/snowpark/internal/OpenTelemetry.scala index 82108d43..8a219847 100644 --- a/src/main/scala/com/snowflake/snowpark/internal/OpenTelemetry.scala +++ b/src/main/scala/com/snowflake/snowpark/internal/OpenTelemetry.scala @@ -60,51 +60,33 @@ object OpenTelemetry extends Logging { execName: String, execHandler: String, execFilePath: String)(func: => T): T = { - try { - spanInfo.withValue[T](spanInfo.value match { - // empty info means this is the entry of the recursion - case None => - val stacks = Thread.currentThread().getStackTrace - val (fileName, lineNumber) = findLineNumber(stacks) - Some( - UdfInfo( - className, - funcName, - fileName, - lineNumber, - execName, - execHandler, - execFilePath)) - // if value is not empty, this function call should be recursion. - // do not issue new SpanInfo, use the info inherited from previous. - case other => other - }) { - val result: T = func - OpenTelemetry.emit(spanInfo.value.get) - result - } - } catch { - case error: Throwable => - OpenTelemetry.reportError(className, funcName, error) - throw error - } + val stacks = Thread.currentThread().getStackTrace + val (fileName, lineNumber) = findLineNumber(stacks) + val newSpan = + UdfInfo(className, funcName, fileName, lineNumber, execName, execHandler, execFilePath) + emitSpan(newSpan, className, funcName, func) } // wrapper of all action functions def action[T](className: String, funcName: String, methodChain: String)(func: => T): T = { + val stacks = Thread.currentThread().getStackTrace + val (fileName, lineNumber) = findLineNumber(stacks) + val newInfo = + ActionInfo(className, funcName, fileName, lineNumber, s"$methodChain.$funcName") + emitSpan(newInfo, className, funcName, func) + } + + private def emitSpan[T](span: SpanInfo, className: String, funcName: String, thunk: => T): T = { try { - spanInfo.withValue[T](spanInfo.value match { - // empty info means this is the entry of the recursion + spanInfo.value match { case None => - val stacks = Thread.currentThread().getStackTrace - val (fileName, lineNumber) = findLineNumber(stacks) - Some(ActionInfo(className, funcName, fileName, lineNumber, s"$methodChain.$funcName")) - // if value is not empty, this function call should be recursion. - // do not issue new SpanInfo, use the info inherited from previous. - case other => other - }) { - val result: T = func - OpenTelemetry.emit(spanInfo.value.get) - result + spanInfo.withValue(Some(span)) { + val result: T = thunk + // only emit one time, in the top level action + OpenTelemetry.emit(spanInfo.value.get) + result + } + case _ => + thunk } } catch { case error: Throwable => diff --git a/src/test/scala/com/snowflake/snowpark_test/OpenTelemetrySuite.scala b/src/test/scala/com/snowflake/snowpark_test/OpenTelemetrySuite.scala index 4e8f6d58..15118fd3 100644 --- a/src/test/scala/com/snowflake/snowpark_test/OpenTelemetrySuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/OpenTelemetrySuite.scala @@ -440,6 +440,12 @@ class OpenTelemetrySuite extends OpenTelemetryEnabled { checkSpanError("snow.snowpark.ClassA1", "functionB1", error) } + test("only emit span once in the nested actions") { + session.sql("select 1").count() + val l = testSpanExporter.getFinishedSpanItems + assert(l.size() == 1) + } + override def beforeAll: Unit = { super.beforeAll createStage(stageName1)