diff --git a/build.sbt b/build.sbt index 8af44512..6fbc6932 100644 --- a/build.sbt +++ b/build.sbt @@ -194,6 +194,7 @@ lazy val lambdaOtel4s = crossProject(JSPlatform, JVMPlatform) libraryDependencies ++= Seq( "org.typelevel" %%% "otel4s-core-trace" % otel4sVersion, "org.typelevel" %%% "otel4s-semconv" % otel4sVersion, + "org.typelevel" %%% "otel4s-sdk-trace-testkit" % otel4sVersion % Test, "org.scalameta" %%% "munit-scalacheck" % munitVersion % Test, "org.typelevel" %%% "munit-cats-effect-3" % munitCEVersion % Test ) diff --git a/lambda-otel4s/js/src/test/scala/feral/lambda/otel4s/TracedHandlerSuite.scala b/lambda-otel4s/js/src/test/scala/feral/lambda/otel4s/TracedHandlerSuite.scala new file mode 100644 index 00000000..d2b8e087 --- /dev/null +++ b/lambda-otel4s/js/src/test/scala/feral/lambda/otel4s/TracedHandlerSuite.scala @@ -0,0 +1,142 @@ +package feral.lambda +package otel4s + +import cats.effect.IO +import cats.effect.kernel.Resource +import cats.syntax.all._ +import feral.lambda.IOLambda +import io.circe.Decoder +import io.circe.Encoder +import io.circe.scalajs._ +import munit.CatsEffectSuite +import org.typelevel.otel4s.Attribute +import org.typelevel.otel4s.sdk.testkit.trace.TracesTestkit +import org.typelevel.otel4s.trace.SpanKind + +import java.util.concurrent.atomic.AtomicInteger +import scala.scalajs.js + +class TracedHandlerSuite extends CatsEffectSuite { + import TracedHandlerSuite._ + + val fixture = ResourceFixture(TracesTestkit.inMemory[IO]()) + + fixture.test("single root span is created for single invocation") { traces => + traces.tracerProvider.tracer("test-tracer").get.flatMap { implicit tracer => + val allocationCounter = new AtomicInteger + val invokeCounter = new AtomicInteger + + val lambda = new IOLambda[TestEvent, String] { + def handler = + Resource.eval(IO(allocationCounter.getAndIncrement())).as { implicit inv => + def fn(ev: TestEvent): IO[Option[String]] = + for { + _ <- IO(invokeCounter.getAndIncrement()) + res = Some(ev.payload) + } yield res + + TracedHandler(fn) + } + } + + val event = TestEvent("1", "body") + + val functionName = "test-function-name" + val run = IO.fromPromise(IO(lambda.handlerFn(event.asJsAny, DummyContext(functionName)))) + + for { + res <- run + spans <- traces.finishedSpans + _ <- IO { + assertEquals(res, "body".toString.asInstanceOf[js.UndefOr[js.Any]]) + assertEquals(spans.length, 1) + assertEquals(spans.headOption.map(_.name), Some(functionName)) + } + } yield () + } + + } + + fixture.test("multiple root span per invocation created with function name ") { traces => + traces.tracerProvider.tracer("test-tracer").get.flatMap { implicit tracer => + val allocationCounter = new AtomicInteger + val invokeCounter = new AtomicInteger + + val lambda = new IOLambda[TestEvent, String] { + def handler = + Resource.eval(IO(allocationCounter.getAndIncrement())).as { implicit inv => + def fn(ev: TestEvent): IO[Option[String]] = + for { + _ <- IO(invokeCounter.getAndIncrement()) + res = Some(ev.payload) + } yield res + + TracedHandler(fn) + } + } + + val functionName = "test-function-name" + val chars = 'A'.to('C').toList + val run = + chars.zipWithIndex.map { case (c, i) => TestEvent(i.toString, c.toString) }.traverse { + e => IO.fromPromise(IO(lambda.handlerFn(e.asJsAny, DummyContext(functionName)))) + } + + val expectedSpanNames = List.fill(3)(functionName) + + for { + res <- run + spans <- traces.finishedSpans + _ <- IO { + assertEquals(res.length, chars.length) + assertEquals(spans.length, chars.length) + assertEquals(spans.map(_.name), expectedSpanNames) + } + } yield () + } + } + + object DummyContext { + def apply(fnName: String): facade.Context = new facade.Context { + def functionName = fnName + def functionVersion = "" + def invokedFunctionArn = "" + def memoryLimitInMB = "0" + def awsRequestId = "" + def logGroupName = "" + def logStreamName = "" + def identity = js.undefined + def clientContext = js.undefined + def getRemainingTimeInMillis(): Double = 0 + } + } + +} + +object TracedHandlerSuite { + + case class TestEvent(traceId: String, payload: String) + + object TestEvent { + + implicit val decoder: Decoder[TestEvent] = + Decoder.forProduct2("traceId", "payload")(TestEvent.apply) + implicit val encoder: Encoder[TestEvent] = + Encoder.forProduct2("traceId", "payload")(ev => (ev.traceId, ev.payload)) + + implicit val attr: EventSpanAttributes[TestEvent] = + new EventSpanAttributes[TestEvent] { + + override def contextCarrier(e: TestEvent): Map[String, String] = + Map("trace_id" -> e.traceId) + + override def spanKind: SpanKind = SpanKind.Consumer + + override def attributes(e: TestEvent): List[Attribute[_]] = List.empty + + } + } + + def tracedLambda(allocationCounter: AtomicInteger, invokeCounter: AtomicInteger) = {} + +} diff --git a/lambda-otel4s/shared/src/main/scala/feral/lambda/otel4s/TracedHandler.scala b/lambda-otel4s/shared/src/main/scala/feral/lambda/otel4s/TracedHandler.scala index 6bc0d710..cca2659e 100644 --- a/lambda-otel4s/shared/src/main/scala/feral/lambda/otel4s/TracedHandler.scala +++ b/lambda-otel4s/shared/src/main/scala/feral/lambda/otel4s/TracedHandler.scala @@ -20,6 +20,8 @@ import cats.Monad import cats.syntax.all._ import feral.lambda.Invocation import org.typelevel.otel4s.trace.Tracer +import org.typelevel.otel4s.trace.SpanOps +import feral.lambda.Context object TracedHandler { @@ -33,19 +35,21 @@ object TracedHandler { event <- inv.event context <- inv.context res <- Tracer[F].joinOrRoot(attr.contextCarrier(event)) { - val spanR = - Tracer[F] - .spanBuilder(context.functionName) - .addAttributes(LambdaContextTraceAttributes(context)) - .withSpanKind(attr.spanKind) - .addAttributes(attr.attributes(event)) - .build - - spanR.surround { + buildSpan(event, context).surround { for { res <- handler(event) } yield res } } } yield res + + def buildSpan[F[_]: Tracer, Event](event: Event, context: Context[F])( + implicit attr: EventSpanAttributes[Event] + ): SpanOps[F] = + Tracer[F] + .spanBuilder(context.functionName) + .addAttributes(LambdaContextTraceAttributes(context)) + .withSpanKind(attr.spanKind) + .addAttributes(attr.attributes(event)) + .build }